diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1657c912c3..c559d08fd7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,8 +4,8 @@ # in the repo. Unless a later match takes precedence, # @deepset-ai/core-engineering will be requested for review # when someone opens a pull request. -* @deepset-ai/core-engineering +* @deepset-ai/open-source-engineering # Documentation -*.md @deepset-ai/documentation @deepset-ai/core-engineering -releasenotes/notes/* @deepset-ai/documentation @deepset-ai/core-engineering +*.md @deepset-ai/documentation @deepset-ai/open-source-engineering +releasenotes/notes/* @deepset-ai/documentation @deepset-ai/open-source-engineering diff --git a/.github/labeler.yml b/.github/labeler.yml index 03b320d9c4..5d926787cf 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -2,11 +2,6 @@ proposal: - proposals/text/* -# 2.x -2.x: -- haystack/preview/**/* -- test/preview/**/* - # Topics topic:tests: - test/**/* diff --git a/.github/utils/pydoc-markdown.sh b/.github/utils/pydoc-markdown.sh index 39bc6dd78e..4322da034e 100755 --- a/.github/utils/pydoc-markdown.sh +++ b/.github/utils/pydoc-markdown.sh @@ -9,11 +9,3 @@ for file in ../config/* ; do echo "Converting $file..." pydoc-markdown "$file" done -# render preview markdown docs -cd .. -rm -rf temp-preview && mkdir temp-preview -cd temp-preview -for file in ../config-preview/* ; do - echo "Converting $file..." - pydoc-markdown "$file" -done diff --git a/.github/workflows/ci_metrics.yml b/.github/workflows/ci_metrics.yml index 31a329e334..688f3b9e1b 100644 --- a/.github/workflows/ci_metrics.yml +++ b/.github/workflows/ci_metrics.yml @@ -4,10 +4,8 @@ on: workflow_run: workflows: - "end-to-end" - - "end-to-end (Preview)" - "Linting" - "Tests" - - "Tests (Preview)" - "REST API Tests" types: - completed diff --git a/.github/workflows/docker_release.yml b/.github/workflows/docker_release.yml index b9f2e8f60c..27f31bb0ba 100644 --- a/.github/workflows/docker_release.yml +++ b/.github/workflows/docker_release.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: push: branches: - - main + - v1.x tags: - "v[0-9].[0-9]+.[0-9]+*" diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index af42e93542..5512f571bd 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -13,7 +13,6 @@ on: - ready_for_review paths: - "e2e/**/*.py" - - "!e2e/preview/**/*.py" # See e2e_preview.yml - ".github/workflows/e2e.yml" env: @@ -31,7 +30,6 @@ jobs: folder: - "document_search" - "pipelines" - - "preview" runs-on: ubuntu-latest @@ -59,7 +57,7 @@ jobs: run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' --env ENABLE_EXPERIMENTAL_BM25='true' --env DISK_USE_READONLY_PERCENTAGE='95' semitechnologies/weaviate:1.17.2 - name: Install Haystack - run: pip install -e .[inference,elasticsearch7,faiss,weaviate,opensearch,dev,pdf,preview] langdetect + run: pip install -e .[inference,elasticsearch7,faiss,weaviate,opensearch,dev,pdf] langdetect # FIXME caching prevents PRs from running the e2e tests properly diff --git a/.github/workflows/e2e_preview.yml b/.github/workflows/e2e_preview.yml deleted file mode 100644 index 2e9763d280..0000000000 --- a/.github/workflows/e2e_preview.yml +++ /dev/null @@ -1,42 +0,0 @@ -# If you change this name also do it in ci_metrics.yml -name: end-to-end (Preview) - -on: - workflow_dispatch: # Activate this workflow manually - schedule: - - cron: "0 0 * * *" - pull_request: - types: 
- - opened - - reopened - - synchronize - - ready_for_review - paths: - - "e2e/preview/**/*.py" - - ".github/workflows/e2e_preview.yml" - -env: - PYTHON_VERSION: "3.8" - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - -jobs: - run: - timeout-minutes: 60 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt install ffmpeg # for local Whisper tests - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' - - - name: Run tests - run: pytest e2e/preview diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml deleted file mode 100644 index c0a7f10887..0000000000 --- a/.github/workflows/examples_tests.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: Examples tests - -on: - workflow_dispatch: # Activate this workflow manually - push: - branches: - - main - pull_request: - paths: - - "examples/**" - - "!examples/preview/**" - types: - - opened - - reopened - - synchronize - - ready_for_review - -env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - HUGGINGFACE_API_KEY: ${{ secrets.HUGGINGFACE_API_KEY }} - PYTHON_VERSION: "3.8" - -jobs: - tests: - name: Examples - runs-on: ubuntu-latest - services: - elasticsearch: - image: elasticsearch:7.17.6 - env: - discovery.type: "single-node" - ES_JAVA_OPTS: "-Xms128m -Xmx256m" - ports: - - 9200:9200 - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Haystack - run: | - pip install --upgrade pip - pip install .[inference,dev,elasticsearch,preprocessing,file-conversion] - - - name: Run - run: pytest examples/ --ignore examples/preview/ - - - name: Calculate alert data - id: calculator - if: (success() || failure()) && github.ref_name == 'main' - shell: bash - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" diff --git a/.github/workflows/github_release.yml b/.github/workflows/github_release.yml index 1940ac5c00..87a42c4e25 100644 --- a/.github/workflows/github_release.yml +++ b/.github/workflows/github_release.yml @@ -53,3 +53,9 @@ jobs: - name: Debug run: | cat relnotes.md + + - uses: ncipollo/release-action@v1 + with: + bodyFile: "relnotes.md" + prerelease: ${{ steps.version.outputs.current_pre_release }} + allowUpdates: true diff --git 
a/.github/workflows/license_compliance.yml b/.github/workflows/license_compliance.yml index b3e0aba19a..bbaa2395cb 100644 --- a/.github/workflows/license_compliance.yml +++ b/.github/workflows/license_compliance.yml @@ -42,10 +42,9 @@ jobs: # Exclusions in the vanilla distribution must be explicitly motivated # # - tqdm is MLP but there are no better alternatives - # - PyMuPDF is optional # - pinecone-client is optional # - psycopg2 is optional - exclude: "(?i)^(PyMuPDF|tqdm|pinecone-client|psycopg2).*" + exclude: "(?i)^(tqdm|pinecone-client|psycopg2).*" # We keep the license inventory on FOSSA - name: Send license report to Fossa @@ -199,7 +198,7 @@ jobs: # Special cases: # - pyzmq is flagged because dual-licensed, but we assume using BSD # - tqdm is MLP but there are no better alternatives - exclude: "(?i)^(astroid|certifi|chardet|num2words|nvidia-|pathspec|pinecone-client|psycopg2|pylint|PyMuPDF|pyzmq|tqdm).*" + exclude: "(?i)^(astroid|certifi|chardet|num2words|nvidia-|pathspec|pinecone-client|psycopg2|pylint|pyzmq|tqdm).*" - name: Print report if: ${{ always() }} @@ -272,7 +271,7 @@ jobs: # Special cases: # - pyzmq is flagged because dual-licensed, but we assume using BSD # - tqdm is MLP but there are no better alternatives - exclude: "(?i)^(astroid|certifi|chardet|num2words|nvidia-|pathspec|pinecone-client|psycopg2|pylint|PyMuPDF|pyzmq|tqdm).*" + exclude: "(?i)^(astroid|certifi|chardet|num2words|nvidia-|pathspec|pinecone-client|psycopg2|pylint|pyzmq|tqdm).*" - name: Print report if: ${{ always() }} diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 2a8a6d96e7..fdd385a23c 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -6,9 +6,6 @@ on: paths: - "**.py" - "**/pyproject.toml" - - "!haystack/preview/**/*.py" - - "!test/preview/**/*.py" - - "!e2e/preview/**/*.py" env: PYTHON_VERSION: "3.8" @@ -74,7 +71,7 @@ jobs: - name: Install Haystack run: | - pip install ".[all,dev]" + pip install ".[all,dev]" pydoc-markdown pip install ./haystack-linter - name: Pylint diff --git a/.github/workflows/linting_preview.yml b/.github/workflows/linting_preview.yml deleted file mode 100644 index 1c7209d138..0000000000 --- a/.github/workflows/linting_preview.yml +++ /dev/null @@ -1,78 +0,0 @@ -# If you change this name also do it in linting-skipper.yml and ci_metrics.yml -name: Linting (Preview) - -on: - pull_request: - paths: - - "haystack/preview/**/*.py" - - "test/preview/**/*.py" - - "e2e/preview/**/*.py" - - "**/pyproject.toml" - -env: - PYTHON_VERSION: "3.8" - -jobs: - mypy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - # With the default value of 1, there are corner cases where tj-actions/changed-files - # fails with a `no merge base` error - fetch-depth: 0 - - - name: Get changed files - id: files - uses: tj-actions/changed-files@v40 - with: - files: | - **/*.py - files_ignore: | - test/** - rest_api/test/** - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - - - name: Mypy - if: steps.files.outputs.any_changed == 'true' - run: | - mkdir .mypy_cache/ - mypy --install-types --non-interactive ${{ steps.files.outputs.all_changed_files }} --exclude=rest_api/build/ --exclude=rest_api/test/ - - pylint: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: 
actions/checkout@v4 - with: - # With the default value of 1, there are corner cases where tj-actions/changed-files - # fails with a `no merge base` error - fetch-depth: 0 - - - name: Get changed files - id: files - uses: tj-actions/changed-files@v40 - with: - files: | - haystack/preview/**/*.py - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Haystack - run: | - pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - pip install ./haystack-linter - - - name: Pylint - if: steps.files.outputs.any_changed == 'true' - run: | - pylint -ry -j 0 ${{ steps.files.outputs.all_changed_files }} diff --git a/.github/workflows/linting_skipper.yml b/.github/workflows/linting_skipper.yml index 0176579329..1430c845e6 100644 --- a/.github/workflows/linting_skipper.yml +++ b/.github/workflows/linting_skipper.yml @@ -6,9 +6,6 @@ on: paths-ignore: - "**.py" - "**/pyproject.toml" - - "!haystack/preview/**/*.py" - - "!test/preview/**/*.py" - - "!e2e/preview/**/*.py" jobs: mypy: diff --git a/.github/workflows/minor_version_release.yml b/.github/workflows/minor_version_release.yml index b402ca4ca3..ee7c42e185 100644 --- a/.github/workflows/minor_version_release.yml +++ b/.github/workflows/minor_version_release.yml @@ -1,4 +1,4 @@ -name: Minor Version Release +name: Minor Version Release (1.x) on: workflow_dispatch: @@ -12,6 +12,8 @@ jobs: steps: - name: Checkout this repo uses: actions/checkout@v4 + with: + ref: "v1.x" - name: Define all versions id: versions @@ -21,29 +23,28 @@ jobs: echo "current_release_minor=$(cut -d "." -f 1,2 < VERSION.txt)" >> "$GITHUB_OUTPUT" - name: Create new version branch - # We tag the commit where we branch off as "-rc0", so reno will know where to stop next - # time we generate release notes for "next minor". run: | git config --global user.name "github-actions[bot]" git config --global user.email "github-actions[bot]@users.noreply.github.com" - git tag -m"v${{ steps.versions.outputs.current_release_minor }}.0-rc0" v${{ steps.versions.outputs.current_release_minor }}.0-rc0 git checkout -b v${{ steps.versions.outputs.current_release_minor }}.x - git push -u origin v${{ steps.versions.outputs.current_release_minor }}.x --tags + git push -u origin v${{ steps.versions.outputs.current_release_minor }}.x - - name: Bump version on main + - name: Bump version on v1.x shell: bash env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - git checkout main + git checkout v1.x NEW_VERSION=$(awk -F. '/[0-9]+\./{$2++;print}' OFS=. < VERSION.txt) echo "$NEW_VERSION" > VERSION.txt cat VERSION.txt git add . git commit -m "Update unstable version to $NEW_VERSION" + # We tag the commit where we branch off as "-rc0", so reno will know where to stop next + # time we generate release notes for "next minor". 
VERSION_TAG="v$NEW_VERSION" git tag $VERSION_TAG -m"$VERSION_TAG" - git push --atomic origin main $VERSION_TAG + git push --atomic origin v1.x $VERSION_TAG - uses: actions/setup-python@v4 with: @@ -56,5 +57,5 @@ jobs: env: RDME_API_KEY: ${{ secrets.README_API_KEY }} run: | - git checkout main + git checkout v1.x python ./.github/utils/release_docs.py --new-version ${{ steps.versions.outputs.current_release_minor }} diff --git a/.github/workflows/pipeline_schema.yml b/.github/workflows/pipeline_schema.yml index b2b03e4f2c..22ef0e62e2 100644 --- a/.github/workflows/pipeline_schema.yml +++ b/.github/workflows/pipeline_schema.yml @@ -1,5 +1,5 @@ name: YAML Schema -run-name: Update schema for ref ${{ github.event.workflow_run.head_branch || inputs.ref || 'main' }} +run-name: Update schema for ref ${{ github.event.workflow_run.head_branch || inputs.ref || 'v1.x' }} on: workflow_dispatch: # Activate this workflow manually @@ -8,21 +8,23 @@ on: description: Tag or branch name of Haystack version required: true type: string - default: main + default: v1.x workflow_run: workflows: - Docker image release types: - completed + branches: + - v1.x env: - HAYSTACK_REF: ${{ github.event.workflow_run.head_branch || inputs.ref || 'main' }} + HAYSTACK_REF: ${{ github.event.workflow_run.head_branch || inputs.ref || 'v1.x' }} jobs: run: - name: Update schema for ref ${{ github.event.workflow_run.head_branch || inputs.ref || 'main' }} + name: Update schema for ref ${{ github.event.workflow_run.head_branch || inputs.ref || 'v1.x' }} if: ${{ !github.event.workflow_run || github.event.workflow_run.conclusion == 'success' }} runs-on: ubuntu-latest steps: diff --git a/.github/workflows/preview_imports.yml b/.github/workflows/preview_imports.yml deleted file mode 100644 index ba2d70f76d..0000000000 --- a/.github/workflows/preview_imports.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: Verify preview imports only preview - -on: - pull_request: - types: - - opened - - reopened - - synchronize - - ready_for_review - paths: - - "haystack/preview/**.py" - -jobs: - verify-imports: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - # With the default value of 1, there are corner cases where tj-actions/changed-files - # fails with a `no merge base` error - fetch-depth: 0 - - - name: Get changed files - id: files - uses: tj-actions/changed-files@v40 - with: - files: | - haystack/preview/**.py - - - name: Check imports - shell: python - run: | - import re - regex = r"^(from haystack|import haystack)(?!\.preview| import preview)(.*)" - - changed_files = "${{ steps.files.outputs.all_changed_files }}".split() - matches = {} - for path in changed_files: - with open(path, "r") as f: - file_matches = [] - for line in f.readlines(): - file_matches.extend(re.finditer(regex, line.strip())) - if file_matches: - matches[path] = file_matches - - for path, match in matches.items(): - print(f"Bad imports in file '{path}'") - for m in match: - print(m.group()) - print() - - if matches: - print("::error:: Imports in haystack.preview can only import from haystack.preview") - import sys; sys.exit(1) diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml index 3118198e65..38afaf35a5 100644 --- a/.github/workflows/pypi_release.yml +++ b/.github/workflows/pypi_release.yml @@ -1,9 +1,12 @@ name: Project release on PyPi on: + workflow_dispatch: # manually trigger the release process without a new tag push: tags: - - "v[0-9].[0-9]+.[0-9]+*" + - "v[0-9]+.[0-9]+.[0-9]+*" + # We must not 
release versions tagged with -rc0 suffix + - "!v[0-9]+.[0-9]+.[0-9]-rc0" jobs: release-on-pypi: diff --git a/.github/workflows/pypi_release_preview.yml b/.github/workflows/pypi_release_preview.yml deleted file mode 100644 index 91e9b0f7e4..0000000000 --- a/.github/workflows/pypi_release_preview.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Trigger preview release - -on: - push: - branches: - - main - paths: - - "haystack/preview/**.py" - -jobs: - release-on-pypi: - runs-on: ubuntu-latest - - steps: - - name: Trigger preview release - env: - HAYSTACK_BOT_REPO_DISPATCH_PA_TOKEN: ${{ secrets.HAYSTACK_BOT_REPO_DISPATCH_PA_TOKEN }} - run: | - curl -L \ - -X POST \ - -H "Authorization: Bearer $HAYSTACK_BOT_REPO_DISPATCH_PA_TOKEN" \ - https://api.github.com/repos/deepset-ai/haystack-preview-package/dispatches \ - -d '{"event_type":"preview_release"}' diff --git a/.github/workflows/readme_sync.yml b/.github/workflows/readme_sync.yml index f13389a70d..162c5aede7 100644 --- a/.github/workflows/readme_sync.yml +++ b/.github/workflows/readme_sync.yml @@ -6,9 +6,9 @@ on: - "docs/pydoc/**" push: branches: - - main + - v1.x # release branches have the form v1.9.x - - "v[0-9].*[0-9].x" + - "v1.*[0-9].x" jobs: sync: @@ -46,23 +46,13 @@ jobs: # Instead of putting more logic into the previous step, let's just assume that commits on `main` # will always be synced to the current `X.Y-unstable` version on Readme id: sync-main - if: github.ref_name == 'main' && github.event_name == 'push' + if: github.ref_name == 'v1.x' && github.event_name == 'push' uses: readmeio/rdme@8.3.1 env: README_API_KEY: ${{ secrets.README_API_KEY }} with: rdme: docs ./docs/pydoc/temp --key="$README_API_KEY" --version=${{ steps.current-version.outputs.minor }}-unstable - - name: Sync preview docs with 2.0 - # Sync the preview docs to the `2.0` version on Readme - id: sync-main-preview - if: github.ref_name == 'main' && github.event_name == 'push' - uses: readmeio/rdme@8.3.1 - env: - README_API_KEY: ${{ secrets.README_API_KEY }} - with: - rdme: docs ./docs/pydoc/temp-preview --key="$README_API_KEY" --version=2.0 - - name: Sync docs with current release # Mutually exclusive with the previous one, this step is supposed to only run on version branches. # Sync the current Haystack version `X.Y.Z` with its corresponding Readme version `X.Y`. diff --git a/.github/workflows/release_notes.yml b/.github/workflows/release_notes.yml index 2938053cae..1c50328819 100644 --- a/.github/workflows/release_notes.yml +++ b/.github/workflows/release_notes.yml @@ -36,7 +36,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'false' && !contains( github.event.pull_request.labels.*.name, 'ignore-for-release-notes') run: | # Check if any of the commit messages contain tags ci/docs/test - if git log --pretty=%s origin/main..HEAD | grep -E '^(ci:|docs:|test:)' > /dev/null; then + if git log --pretty=%s origin/v1.x..HEAD | grep -E '^(ci:|docs:|test:)' > /dev/null; then echo "Skipping release note check for commits with 'ci:', 'docs:', or 'test:' tags." else echo "::error::The release notes file is missing, please add one or attach the label 'ignore-for-release-notes' to this PR." 
diff --git a/.github/workflows/rest_api_tests.yml b/.github/workflows/rest_api_tests.yml index e3fc5a5971..2da6bf564b 100644 --- a/.github/workflows/rest_api_tests.yml +++ b/.github/workflows/rest_api_tests.yml @@ -110,7 +110,7 @@ jobs: - name: Install REST API run: | pip install -U "./rest_api[dev]" - pip install ".[inference,dev]" + pip install ".[inference,dev,preprocessing]" pip install . - name: Run tests diff --git a/.github/workflows/snippets_tests.yml b/.github/workflows/snippets_tests.yml deleted file mode 100644 index c12ea60099..0000000000 --- a/.github/workflows/snippets_tests.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Test documentation snippets for Haystack 2.x - -on: - workflow_dispatch: # Activate this workflow manually - push: - branches: - - main - pull_request: - paths: - - examples/preview/** - types: - - opened - - reopened - - synchronize - - ready_for_review - -env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - PYTHON_VERSION: "3.8" - -jobs: - tests: - name: Snippets - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install snippets dependencies - run: | - pip install --upgrade pip - pip install ".[preview]" torch - - - name: Get changed files - id: files - uses: tj-actions/changed-files@v40 - with: - files: | - examples/preview/**.py - - - name: Run each snippet - run: | - CHANGED_FILES=${{ steps.files.outputs.all_changed_files }} - for file in $CHANGED_FILES; do - python "$file" - done - - - name: Calculate alert data - id: calculator - if: (success() || failure()) && github.ref_name == 'main' - shell: bash - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 307f3d2916..f973977be5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,9 +17,6 @@ on: paths: - "**.py" - "pyproject.toml" - - "!haystack/preview/**/*.py" # See tests_preview.yml - - "!test/preview/**/*.py" # See tests_preview.yml - - "!e2e/preview/**/*.py" # See e2e_preview.yml - "!.github/**/*.py" - "!rest_api/**/*.py" - "!docs/**/*.py" @@ -125,10 +122,10 @@ jobs: include: - topic: document_stores os: ubuntu-latest - dependencies: elasticsearch8,faiss,weaviate,pinecone,opensearch,inference,crawler,preprocessing,file-conversion,pdf,ocr,metrics,dev + dependencies: elasticsearch8,faiss,weaviate,pinecone,opensearch,mongodb,inference,crawler,preprocessing,file-conversion,pdf,ocr,metrics,dev - topic: document_stores os: windows-latest - dependencies: 
elasticsearch8,faiss,weaviate,pinecone,opensearch,inference,crawler,preprocessing,file-conversion,pdf,ocr,metrics,dev + dependencies: elasticsearch8,faiss,weaviate,pinecone,opensearch,mongodb,inference,crawler,preprocessing,file-conversion,pdf,ocr,metrics,dev runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -143,16 +140,6 @@ jobs: - name: Run run: pytest --cov-report xml:coverage.xml --cov="haystack" -m "unit" test/${{ matrix.topic }} - - name: Coveralls Parallel - # We upload only coverage for ubuntu as handling both os - # complicates the workflow too much for little to no gain - if: matrix.os == 'ubuntu-latest' - uses: coverallsapp/github-action@v2 - with: - path-to-lcov: coverage.xml - flag-name: ${{ matrix.topic }} - parallel: true - - name: Calculate alert data id: calculator shell: bash @@ -184,15 +171,6 @@ jobs: - "branch:${{ github.ref_name }}" - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" - upload-coverage: - needs: unit-tests - runs-on: ubuntu-latest - steps: - - name: Coveralls Finished - uses: coverallsapp/github-action@v2 - with: - parallel-finished: true - integration-tests-elasticsearch7: name: Integration / Elasticsearch7 / ${{ matrix.os }} needs: diff --git a/.github/workflows/tests_preview.yml b/.github/workflows/tests_preview.yml deleted file mode 100644 index f53ed289a7..0000000000 --- a/.github/workflows/tests_preview.yml +++ /dev/null @@ -1,318 +0,0 @@ -# If you change this name also do it in tests_preview_skipper.yml -name: Tests (Preview) - -on: - workflow_dispatch: # Activate this workflow manually - push: - branches: - - main - # release branches have the form v1.9.x - - "v[0-9].*[0-9].x" - pull_request: - types: - - opened - - reopened - - synchronize - - ready_for_review - paths: - - "haystack/preview/**/*.py" - - "test/preview/**/*.py" - -env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} - CORE_AZURE_CS_ENDPOINT: ${{ secrets.CORE_AZURE_CS_ENDPOINT }} - CORE_AZURE_CS_API_KEY: ${{ secrets.CORE_AZURE_CS_API_KEY }} - PYTHON_VERSION: "3.8" - -jobs: - black: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Black - run: | - pip install --upgrade pip - pip install .[formatting] - - - name: Check status - run: | - if ! black . --check; then - git status - echo "###################################################################################################" - echo "# " - echo "# CHECK FAILED! Black found issues with your code formatting." - echo "# " - echo "# Either:" - echo "# 1. Run Black locally before committing:" - echo "# " - echo "# pip install .[formatting]" - echo "# black ." - echo "# " - echo "# 2. Install the pre-commit hook:" - echo "# " - echo "# pre-commit install" - echo "# " - echo "# 3. See https://github.com/deepset-ai/haystack/blob/main/CONTRIBUTING.md for help." 
- echo "# " - echo "# If you have further problems, please open an issue: https://github.com/deepset-ai/haystack/issues" - echo "# " - echo "##################################################################################################" - exit 1 - fi - - - name: Calculate alert data - id: calculator - shell: bash - if: (success() || failure()) && github.ref_name == 'main' - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" - - unit-tests: - name: Unit / ${{ matrix.os }} - needs: black - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - - windows-latest - - macos-latest - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - - - name: Run - run: pytest -m "not integration" test/preview - - - name: Calculate alert data - id: calculator - shell: bash - if: (success() || failure()) && github.ref_name == 'main' - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" - - integration-tests-linux: - name: Integration / ubuntu-latest - needs: unit-tests - runs-on: ubuntu-latest - services: - tika: - image: apache/tika:2.9.0.0 - ports: - - 9998:9998 - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install dependencies - run: | - sudo apt update - sudo apt install ffmpeg # for local Whisper tests - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect 
transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - - - name: Run - run: pytest --maxfail=5 -m "integration" test/preview - - - name: Calculate alert data - id: calculator - shell: bash - if: (success() || failure()) && github.ref_name == 'main' - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" - - integration-tests-macos: - name: Integration / macos-latest - needs: unit-tests - runs-on: macos-latest-xl - env: - HAYSTACK_MPS_ENABLED: false - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install dependencies - run: | - brew install ffmpeg # for local Whisper tests - brew install docker - colima start - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - - - name: Run Tika - run: docker run -d -p 9998:9998 apache/tika:2.9.0.0 - - - name: Run - run: pytest --maxfail=5 -m "integration" test/preview - - - name: Calculate alert data - id: calculator - shell: bash - if: (success() || failure()) && github.ref_name == 'main' - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" - - integration-tests-windows: - name: Integration / windows-latest - needs: unit-tests - runs-on: windows-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain 
tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - - - name: Run - run: pytest --maxfail=5 -m "integration" test/preview -k 'not tika' - - - name: Calculate alert data - id: calculator - shell: bash - if: (success() || failure()) && github.ref_name == 'main' - run: | - if [ "${{ job.status }}" = "success" ]; then - echo "alert_type=success" >> "$GITHUB_OUTPUT"; - else - echo "alert_type=error" >> "$GITHUB_OUTPUT"; - fi - - - name: Send event to Datadog - if: (success() || failure()) && github.ref_name == 'main' - uses: masci/datadog@v1 - with: - api-key: ${{ secrets.CORE_DATADOG_API_KEY }} - api-url: https://api.datadoghq.eu - events: | - - title: "${{ github.workflow }} workflow" - text: "Job ${{ github.job }} in branch ${{ github.ref_name }}" - alert_type: "${{ steps.calculator.outputs.alert_type }}" - source_type_name: "Github" - host: ${{ github.repository_owner }} - tags: - - "project:${{ github.repository }}" - - "job:${{ github.job }}" - - "run_id:${{ github.run_id }}" - - "workflow:${{ github.workflow }}" - - "branch:${{ github.ref_name }}" - - "url:https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" diff --git a/.github/workflows/tests_preview_skipper.yml b/.github/workflows/tests_preview_skipper.yml deleted file mode 100644 index 2f64eaae99..0000000000 --- a/.github/workflows/tests_preview_skipper.yml +++ /dev/null @@ -1,21 +0,0 @@ -# If you change this name also do it in tests_preview.yml -name: Tests (Preview) - -on: - pull_request: - types: - - opened - - reopened - - synchronize - - ready_for_review - paths-ignore: - - "haystack/preview/**/*.py" - - "test/preview/**/*.py" - -jobs: - catch-all: - name: Catch-all check - runs-on: ubuntu-latest - steps: - - name: Skip preview tests - run: echo "Skipped!" diff --git a/.github/workflows/tests_skipper.yml b/.github/workflows/tests_skipper.yml index 5ac396aebd..c4e4dbae09 100644 --- a/.github/workflows/tests_skipper.yml +++ b/.github/workflows/tests_skipper.yml @@ -11,9 +11,6 @@ on: paths-ignore: - "**.py" - "pyproject.toml" - - "!haystack/preview/**/*.py" # See tests_preview.yml - - "!test/preview/**/*.py" # See tests_preview.yml - - "!e2e/preview/**/*.py" # See e2e_preview.yml - "!.github/**/*.py" - "!rest_api/**/*.py" - "!docs/**/*.py" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5375c94d75..a8108b9190 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: ruff - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: b8ecc9b3acf31690c3cb2fc5bb03a3fbbbc2d7a3 hooks: - id: codespell additional_dependencies: diff --git a/README.md b/README.md index 119515107e..8a724eeedd 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,28 @@ -
- Haystack - -| | | -| ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| CI/CD | [![Tests](https://github.com/deepset-ai/haystack/actions/workflows/tests.yml/badge.svg)](https://github.com/deepset-ai/haystack/actions/workflows/tests.yml) [![Docker image release](https://github.com/deepset-ai/haystack/actions/workflows/docker_release.yml/badge.svg)](https://github.com/deepset-ai/haystack/actions/workflows/docker_release.yml) [![Schemas](https://github.com/deepset-ai/haystack/actions/workflows/schemas.yml/badge.svg)](https://github.com/deepset-ai/haystack/actions/workflows/schemas.yml) [![code style - Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![types - Mypy](https://img.shields.io/badge/types-Mypy-blue.svg)](https://github.com/python/mypy) [![Coverage Status](https://coveralls.io/repos/github/deepset-ai/haystack/badge.svg?branch=main)](https://coveralls.io/github/deepset-ai/haystack?branch=main) | -| Docs | [![Sync docs with Readme](https://github.com/deepset-ai/haystack/actions/workflows/readme_sync.yml/badge.svg)](https://github.com/deepset-ai/haystack/actions/workflows/readme_sync.yml) [![Website](https://img.shields.io/website?label=documentation&up_message=online&url=https%3A%2F%2Ffanyv88.com%3A443%2Fhttps%2Fdocs.haystack.deepset.ai)](https://docs.haystack.deepset.ai) | -| Package | [![PyPI](https://img.shields.io/pypi/v/farm-haystack)](https://pypi.org/project/farm-haystack/) ![PyPI - Downloads](https://img.shields.io/pypi/dm/farm-haystack?color=blue&logo=pypi&logoColor=gold) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/farm-haystack?logo=python&logoColor=gold) [![GitHub](https://img.shields.io/github/license/deepset-ai/haystack?color=blue)](LICENSE) [![License Compliance](https://github.com/deepset-ai/haystack/actions/workflows/license_compliance.yml/badge.svg)](https://github.com/deepset-ai/haystack/actions/workflows/license_compliance.yml) | -| Meta | [![Discord](https://img.shields.io/discord/993534733298450452?logo=discord)](https://discord.gg/haystack) [![Twitter Follow](https://img.shields.io/twitter/follow/haystack_ai)](https://twitter.com/haystack_ai) | -
+> ⚠️ **End of Life Notice** +> +> Haystack version 1.x reached End of Life (EOL) on March 11, 2025, and is no longer receiving updates or support. The final version released is 1.26.4. We recommend migrating to Haystack version 2.x, which has been stable and available since March 2024. It is distributed via a different package named [haystack-ai](https://pypi.org/project/haystack-ai/). +> +> **Why Upgrade to Haystack 2.x?** +> +> - More Flexible & Composable Pipelines: We introduced cyclic [pipeline](https://docs.haystack.deepset.ai/docs/pipelines) graphs, allowing for loops, condition-based routing and concurrent execution, which are essential for modern LLM applications. +> - Customizable & Extensible Components: While there are many ready-made components, including an Agent component, creating [custom components](https://docs.haystack.deepset.ai/docs/custom-components) is core functionality: all you need to do is decorate your custom logic with @component. +> - Improved 70+ Integrations: Unified interfaces for document stores and chat generators support a broad range of vector databases. Plus, all [integrations](https://haystack.deepset.ai/integrations) are built for robust, real-world production use and tested nightly. +> - Production-Ready Features: Enhanced [structured logging](https://docs.haystack.deepset.ai/docs/logging), [tracing](https://docs.haystack.deepset.ai/docs/tracing), and [Hayhooks](https://docs.haystack.deepset.ai/docs/hayhooks) make it easy to deploy and serve pipelines as RESTful APIs. +> +> **Migration Resources:** +> - [Migration Guide](https://docs.haystack.deepset.ai/docs/migration) - Learn how to migrate your applications to Haystack 2.x +> - [Historical Documentation](https://core-engineering.s3.eu-central-1.amazonaws.com/public/docs/haystack-v1-docs.zip) - Download the complete documentation for Haystack 1.x (versions 1.0 to 1.26) +> - [GitHub History](https://github.com/deepset-ai/haystack-tutorials/tree/5917718cbfbb61410aab4121ee6fe754040a5dc7) - Access old tutorials and examples in the repository history +> +> **Important Migration Note:** +> The package name has changed from `farm-haystack` to `haystack-ai`. These packages cannot coexist in the same Python environment. To migrate: +> ```bash +> pip uninstall -y farm-haystack haystack-ai +> pip install haystack-ai +> ``` [Haystack](https://haystack.deepset.ai/) is an end-to-end NLP framework that enables you to build applications powered by LLMs, Transformer models, vector search and more. Whether you want to perform question answering, answer generation, semantic document search, or build tools that are capable of complex decision-making and query resolution, you can use state-of-the-art NLP models with Haystack to build end-to-end NLP applications to solve your use case. -## Quickstart - -Haystack is built around the concept of pipelines. A pipeline is a powerful structure that performs an NLP task. It's made up of components connected together. For example, you can connect a `Retriever` and a `PromptNode` to build a Generative Question Answering pipeline that uses your own data.
- -Try out how Haystack answers questions about Game of Thrones using the Retrieval Augmented Generation (RAG) approach 👇 - -First, run the minimal Haystack installation: - -```sh -pip install farm-haystack -``` - -Then, index your data to the DocumentStore, build a RAG pipeline, and ask a question on your data: - -```python -from haystack.document_stores import InMemoryDocumentStore -from haystack.utils import build_pipeline, add_example_data, print_answers - -# We are model agnostic :) Here, you can choose from: "anthropic", "cohere", "huggingface", and "openai". -provider = "openai" -API_KEY = "sk-..." # ADD YOUR KEY HERE - -# We support many different databases. Here, we load a simple and lightweight in-memory database. -document_store = InMemoryDocumentStore(use_bm25=True) - -# Download and add Game of Thrones TXT articles to Haystack DocumentStore. -# You can also provide a folder with your local documents. -add_example_data(document_store, "data/GoT_getting_started") - -# Build a pipeline with a Retriever to get relevant documents to the query and a PromptNode interacting with LLMs using a custom prompt. -pipeline = build_pipeline(provider, API_KEY, document_store) - -# Ask a question on the data you just added. -result = pipeline.run(query="Who is the father of Arya Stark?") - -# For details, like which documents were used to generate the answer, look into the object -print_answers(result, details="medium") -``` - -The output of the pipeline will reference the documents used to generate the answer: - -``` -'Query: Who is the father of Arya Stark?' -'Answers:' -[{'answer': 'The father of Arya Stark is Lord Eddard Stark of ' - 'Winterfell. [Document 1, Document 4, Document 5]'}] -``` - -Congratulations, you have just built your first Haystack app! - ## Core Concepts 🏃‍♀️ **[Pipelines](https://docs.haystack.deepset.ai/docs/pipelines):** This is the standard Haystack structure that builds on top of your data to perform various NLP tasks such as retrieval augmented generation, question answering and more. The data in a Pipeline flows from one Node to the next. You define how Nodes interact with each other and how one Node pushes data to the next. @@ -95,19 +57,19 @@ An example pipeline would consist of one `Retriever` Node and one `PromptNode`. - **Continuous Learning**: Collect new training data from user feedback in production & improve your models continuously. ## Resources -| | | -| ---------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| 📒 [Docs](https://docs.haystack.deepset.ai) | Components, Pipeline Nodes, Guides, API Reference | -| 💾 [Installation](https://github.com/deepset-ai/haystack#-installation) | How to install Haystack | -| 🎓 [Tutorials](https://haystack.deepset.ai/tutorials) | See what Haystack can do with our Notebooks & Scripts | -| 🎉 [Haystack Extras](https://github.com/deepset-ai/haystack-extras) | A repository that lists extra Haystack packages and components that can be installed separately. 
| -| 🔰 [Demos](https://github.com/deepset-ai/haystack-demos) | A repository containing Haystack demo applications with Docker Compose and a REST API | +| | | +| ---------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| 📒 [Docs](https://docs.haystack.deepset.ai) | Components, Pipeline Nodes, Guides, API Reference | +| 💾 [Installation](https://github.com/deepset-ai/haystack#-installation) | How to install Haystack | +| 🎓 [Tutorials](https://haystack.deepset.ai/tutorials) | See what Haystack can do with our Notebooks & Scripts | +| 🎉 [Haystack Extras](https://github.com/deepset-ai/haystack-extras) | A repository that lists extra Haystack packages and components that can be installed separately. | +| 🔰 [Demos](https://github.com/deepset-ai/haystack-demos) | A repository containing Haystack demo applications with Docker Compose and a REST API | | 🖖 [Community](https://github.com/deepset-ai/haystack#-community) | [Discord](https://discord.gg/haystack), [𝕏 (Twitter)](https://twitter.com/haystack_ai), [Stack Overflow](https://stackoverflow.com/questions/tagged/haystack), [GitHub Discussions](https://github.com/deepset-ai/haystack/discussions) | -| 💙 [Contributing](https://github.com/deepset-ai/haystack#-contributing) | We welcome all contributions! | -| 📊 [Benchmarks](https://haystack.deepset.ai/benchmarks/) | Speed & Accuracy of Retriever, Readers and DocumentStores | -| 🔭 [Roadmap](https://haystack.deepset.ai/overview/roadmap) | Public roadmap of Haystack | -| 📰 [Blog](https://haystack.deepset.ai/blog) | Learn about the latest with Haystack and NLP | -| ☎️ [Jobs](https://www.deepset.ai/jobs) | We're hiring! Have a look at our open positions | +| 💙 [Contributing](https://github.com/deepset-ai/haystack#-contributing) | We welcome all contributions! | +| 📊 [Benchmarks](https://haystack.deepset.ai/benchmarks/) | Speed & Accuracy of Retriever, Readers and DocumentStores | +| 🔭 [Roadmap](https://haystack.deepset.ai/overview/roadmap) | Public roadmap of Haystack | +| 📰 [Blog](https://haystack.deepset.ai/blog) | Learn about the latest with Haystack and NLP | +| ☎️ [Jobs](https://www.deepset.ai/jobs) | We're hiring! 
Have a look at our open positions | ## 💾 Installation diff --git a/VERSION.txt b/VERSION.txt index 000d4b1086..7464078b57 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.23.0-rc0 +1.26.4.post diff --git a/conftest.py b/conftest.py index 7d27f2659d..475c5dcc91 100644 --- a/conftest.py +++ b/conftest.py @@ -11,9 +11,6 @@ def pytest_addoption(parser): def pytest_generate_tests(metafunc): - # Ugly hack to avoid polluting preview tests this with unwanted fixtures - if "test/preview" in metafunc.module.__file__ or "test\\preview" in metafunc.module.__file__: - return # Get selected docstores from CLI arg document_store_type = metafunc.config.option.document_store_type selected_doc_stores = [item.strip() for item in document_store_type.split(",")] diff --git a/docs/pydoc/config-preview/builder.yml b/docs/pydoc/config-preview/builder.yml deleted file mode 100644 index 7ec4b36141..0000000000 --- a/docs/pydoc/config-preview/builder.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/builders] - modules: ["answer_builder", "prompt_builder", "dynamic_prompt_builder"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Extract the output of a Generator to an Answer format, and build prompts. - category_slug: haystack-classes - title: Builder API - slug: builder-api - order: 5 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: builder_api.md diff --git a/docs/pydoc/config-preview/caching.yml b/docs/pydoc/config-preview/caching.yml deleted file mode 100644 index 7f2bc0e852..0000000000 --- a/docs/pydoc/config-preview/caching.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/caching] - modules: ["url_cache_checker"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Checks if any document coming from the given URL is already present in the store. - category_slug: haystack-classes - title: UrlCacheChecker API - slug: caching-api - order: 160 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: caching_api.md diff --git a/docs/pydoc/config-preview/classifier.yml b/docs/pydoc/config-preview/classifier.yml deleted file mode 100644 index 8a68fa696d..0000000000 --- a/docs/pydoc/config-preview/classifier.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/classifiers] - modules: ["document_language_classifier"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Detects the language of the Documents and routes them appropriately. 
- category_slug: haystack-classes - title: Language Classifier API - slug: language-classifier-api - order: 10 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: language_classifier_api.md diff --git a/docs/pydoc/config-preview/converter.yml b/docs/pydoc/config-preview/converter.yml deleted file mode 100644 index b6f3aae0df..0000000000 --- a/docs/pydoc/config-preview/converter.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/converters] - modules: ["azure", "html", "markdown", "pypdf", "tika", "txt"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Extracts text from files in different formats and converts it into the unified Document format. - category_slug: haystack-classes - title: Converter API - slug: converter-api - order: 50 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: converter_api.md diff --git a/docs/pydoc/config-preview/document_store.yml b/docs/pydoc/config-preview/document_store.yml deleted file mode 100644 index 6c51c46290..0000000000 --- a/docs/pydoc/config-preview/document_store.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/document_stores/in_memory] - modules: ["document_store"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Stores your texts and meta data and provides them to the Retriever at query time. - category_slug: haystack-classes - title: DocumentStore API - slug: document-store-api - order: 20 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: document_store.md diff --git a/docs/pydoc/config-preview/embedder.yml b/docs/pydoc/config-preview/embedder.yml deleted file mode 100644 index 077d292dc2..0000000000 --- a/docs/pydoc/config-preview/embedder.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/embedders] - modules: ["openai_document_embedder", "openai_text_embedder", "sentence_transformers_document_embedder", "sentence_transformers_text_embedder"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Transforms queries into vectors to look for similar or relevant Documents. 
- category_slug: haystack-classes - title: Embedder API - slug: embedder-api - order: 40 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: embedder_api.md diff --git a/docs/pydoc/config-preview/fetcher.yml b/docs/pydoc/config-preview/fetcher.yml deleted file mode 100644 index 7d7a6ee4b9..0000000000 --- a/docs/pydoc/config-preview/fetcher.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/fetchers] - modules: ["link_content"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Fetches content from a list of URLs and returns a list of extracted content streams. - category_slug: haystack-classes - title: LinkContentFetcher API - slug: fetcher-api - order: 70 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: fetcher_api.md diff --git a/docs/pydoc/config-preview/generator.yml b/docs/pydoc/config-preview/generator.yml deleted file mode 100644 index dbed3a820a..0000000000 --- a/docs/pydoc/config-preview/generator.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/generators] - modules: ["hugging_face_local", "hugging_face_tgi", "openai", "chat/hugging_face_tgi", "chat/openai"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Enables text generation using LLMs. - category_slug: haystack-classes - title: Generator API - slug: generator-api - order: 60 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: generator_api.md diff --git a/docs/pydoc/config-preview/pipeline.yml b/docs/pydoc/config-preview/pipeline.yml deleted file mode 100644 index 2b61de5451..0000000000 --- a/docs/pydoc/config-preview/pipeline.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview] - modules: ["pipeline"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Arranges components and integrations in flow. 
- category_slug: haystack-classes - title: Pipelines API - slug: pipelines-api - order: 80 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: pipelines_api.md diff --git a/docs/pydoc/config-preview/preprocessor.yml b/docs/pydoc/config-preview/preprocessor.yml deleted file mode 100644 index d2acf19da4..0000000000 --- a/docs/pydoc/config-preview/preprocessor.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/preprocessors] - modules: ["document_cleaner", "document_splitter"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Normalizes white spaces, gets rid of headers and footers, cleans empty lines in your Documents, or splits them into smaller pieces. - category_slug: haystack-classes - title: PreProcessor API - slug: preprocessor-api - order: 90 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: preprocessor_api.md diff --git a/docs/pydoc/config-preview/ranker.yml b/docs/pydoc/config-preview/ranker.yml deleted file mode 100644 index d0fbe6d687..0000000000 --- a/docs/pydoc/config-preview/ranker.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/rankers] - modules: ["meta_field", "transformers_similarity"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Reorders a set of Documents based on their relevance to the query. - category_slug: haystack-classes - title: Ranker API - slug: ranker-api - order: 110 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: ranker_api.md diff --git a/docs/pydoc/config-preview/reader.yml b/docs/pydoc/config-preview/reader.yml deleted file mode 100644 index 59163e4422..0000000000 --- a/docs/pydoc/config-preview/reader.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/readers] - modules: ["extractive"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Takes a query and a set of Documents as input and returns ExtractedAnswers by selecting a text span within the Documents. 
- category_slug: haystack-classes - title: Reader API - slug: reader-api - order: 120 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: reader_api.md diff --git a/docs/pydoc/config-preview/retriever.yml b/docs/pydoc/config-preview/retriever.yml deleted file mode 100644 index 805753de56..0000000000 --- a/docs/pydoc/config-preview/retriever.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/retrievers] - modules: ["in_memory_bm25_retriever", "in_memory_embedding_retriever"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Sweeps through a Document Store and returns a set of candidate Documents that are relevant to the query. - category_slug: haystack-classes - title: Retriever API - slug: retriever-api - order: 130 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: retriever_api.md diff --git a/docs/pydoc/config-preview/router.yml b/docs/pydoc/config-preview/router.yml deleted file mode 100644 index 4f01b3a869..0000000000 --- a/docs/pydoc/config-preview/router.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/routers] - modules: ["document_joiner", "file_type_router", "metadata_router", "text_language_router"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Routes data to the right component based on its file type or metadata. - category_slug: haystack-classes - title: Router API - slug: router-api - order: 140 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: router_api.md diff --git a/docs/pydoc/config-preview/sampler.yml b/docs/pydoc/config-preview/sampler.yml deleted file mode 100644 index fe1072ab6d..0000000000 --- a/docs/pydoc/config-preview/sampler.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/samplers] - modules: ["top_p"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Filters documents based on their similarity scores using top-p sampling. 
- category_slug: haystack-classes - title: TopPSampler API - slug: sampler-api - order: 150 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: sampler_api.md diff --git a/docs/pydoc/config-preview/websearch.yml b/docs/pydoc/config-preview/websearch.yml deleted file mode 100644 index 6433940e04..0000000000 --- a/docs/pydoc/config-preview/websearch.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/websearch] - modules: ["serper_dev"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Web search engine for Haystack. - category_slug: haystack-classes - title: Websearch API - slug: websearch-api - order: 170 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: websearch_api.md diff --git a/docs/pydoc/config-preview/whisper.yml b/docs/pydoc/config-preview/whisper.yml deleted file mode 100644 index 1d47d80163..0000000000 --- a/docs/pydoc/config-preview/whisper.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/audio] - modules: ["whisper_local", "whisper_remote"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Transcribes audio files. - category_slug: haystack-classes - title: WhisperTranscriber API - slug: whisper-transcriber-api - order: 180 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: whisper_transcriber_api.md diff --git a/docs/pydoc/config-preview/writer.yml b/docs/pydoc/config-preview/writer.yml deleted file mode 100644 index 9ff0095369..0000000000 --- a/docs/pydoc/config-preview/writer.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/preview/components/writers] - modules: ["document_writer"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmePreviewRenderer - excerpt: Writes Documents to a DocumentStore. 
- category_slug: haystack-classes - title: DocumentWriter API - slug: writer-api - order: 30 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: writer_api.md diff --git a/docs/pydoc/config/answer-generator.yml b/docs/pydoc/config/answer-generator.yml deleted file mode 100644 index 285c2d0678..0000000000 --- a/docs/pydoc/config/answer-generator.yml +++ /dev/null @@ -1,26 +0,0 @@ -loaders: - - type: loaders.CustomPythonLoader - search_path: [../../../haystack/nodes/answer_generator] - modules: ["openai"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmeRenderer - excerpt: Reads a set of documents and generates an answer to a question, word by word - category_slug: haystack-classes - title: Answer Generator API - slug: answer-generator-api - order: 5 - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: answer_generator_api.md diff --git a/docs/pydoc/config/base-generator.yml b/docs/pydoc/config/base-generator.yml deleted file mode 100644 index 41d8fd8907..0000000000 --- a/docs/pydoc/config/base-generator.yml +++ /dev/null @@ -1,27 +0,0 @@ -loaders: - - type: python - search_path: [../../../haystack/nodes/answer_generator] - modules: ["base"] - ignore_when_discovered: ["__init__"] -processors: - - type: filter - expression: - documented_only: true - do_not_filter_modules: false - skip_empty_modules: true - - type: smart - - type: crossref -renderer: - type: renderers.ReadmeRenderer - excerpt: Abstract class for Generators. - category_slug: haystack-classes - title: BaseGenerator API - slug: basegenerator-api - order: 7 - parent_doc_slug: answer-generator-api - markdown: - descriptive_class_title: false - descriptive_module_title: true - add_method_class_prefix: true - add_member_class_prefix: false - filename: basegenerator_api.md diff --git a/docs/pydoc/config/document-store.yml b/docs/pydoc/config/document-store.yml index 4e39d54d3d..bfb056265b 100644 --- a/docs/pydoc/config/document-store.yml +++ b/docs/pydoc/config/document-store.yml @@ -7,6 +7,7 @@ loaders: "es8", "opensearch", "memory", + "mongodb_atlas", "sql", "faiss", "weaviate", diff --git a/docs/pydoc/config/file-converters.yml b/docs/pydoc/config/file-converters.yml index 0ca3866831..819aaa22ef 100644 --- a/docs/pydoc/config/file-converters.yml +++ b/docs/pydoc/config/file-converters.yml @@ -10,7 +10,7 @@ loaders: "json", "markdown", "parsr", - "pdf", + "pdf_xpdf", "pptx", "tika", "txt" diff --git a/docs/pydoc/renderers.py b/docs/pydoc/renderers.py index 6450f666cb..c74a60dc7c 100644 --- a/docs/pydoc/renderers.py +++ b/docs/pydoc/renderers.py @@ -133,16 +133,3 @@ def _frontmatter(self) -> str: slug=self.slug, order=self.order, ) - - -@dataclasses.dataclass -class ReadmePreviewRenderer(ReadmeRenderer): - """ - This custom Renderer behaves just like the ReadmeRenderer but renders docs with the hardcoded version 2.0 to generate correct category ids. - """ - - def _doc_version(self) -> str: - """ - Returns the hardcoded docs version 2.0. 
- """ - return "v2.0" diff --git a/e2e/conftest.py b/e2e/conftest.py index 41d02927e2..7308073ca3 100644 --- a/e2e/conftest.py +++ b/e2e/conftest.py @@ -25,11 +25,6 @@ def samples_path(): return Path(__file__).parent / "samples" -@pytest.fixture -def preview_samples_path(): - return Path(__file__).parent / "preview" / "test_files" - - @pytest.fixture def docs_all_formats(): return [ diff --git a/e2e/pipelines/test_pipeline_topologies.py b/e2e/pipelines/test_pipeline_topologies.py index 752191d34a..f2d1a1001e 100644 --- a/e2e/pipelines/test_pipeline_topologies.py +++ b/e2e/pipelines/test_pipeline_topologies.py @@ -178,13 +178,7 @@ def test_join_with_rrf(docs): results = p.run(query=query) # list of precalculated expected results - expected_scores = [ - 0.03278688524590164, - 0.03200204813108039, - 0.03200204813108039, - 0.031009615384615385, - 0.031009615384615385, - ] + expected_scores = [1.0, 0.9684979838709676, 0.9684979838709676, 0.9533577533577533, 0.9533577533577533] assert all( doc.score == pytest.approx(expected_scores[idx], abs=1e-3) for idx, doc in enumerate(results["documents"]) ) diff --git a/e2e/pipelines/test_standard_pipelines.py b/e2e/pipelines/test_standard_pipelines.py index f25ddcd13b..3ebc5a4702 100644 --- a/e2e/pipelines/test_standard_pipelines.py +++ b/e2e/pipelines/test_standard_pipelines.py @@ -207,7 +207,7 @@ def test_webqa_pipeline(): search_key = os.environ.get("SERPERDEV_API_KEY") openai_key = os.environ.get("OPENAI_API_KEY") pn = PromptNode( - "text-davinci-003", + "gpt-3.5-turbo-instruct", api_key=openai_key, max_length=256, default_prompt_template="question-answering-with-document-scores", diff --git a/e2e/preview/conftest.py b/e2e/preview/conftest.py deleted file mode 100644 index 3ad8d8a746..0000000000 --- a/e2e/preview/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -from pathlib import Path - -import pytest -from haystack.preview.testing.test_utils import set_all_seeds - -set_all_seeds(0) - - -@pytest.fixture -def samples_path(): - return Path(__file__).parent / "samples" diff --git a/e2e/preview/pipelines/test_dense_doc_search.py b/e2e/preview/pipelines/test_dense_doc_search.py deleted file mode 100644 index ae2612893d..0000000000 --- a/e2e/preview/pipelines/test_dense_doc_search.py +++ /dev/null @@ -1,84 +0,0 @@ -import json -from pathlib import Path - -from haystack.preview import Pipeline -from haystack.preview.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder -from haystack.preview.components.converters import PyPDFToDocument, TextFileToDocument -from haystack.preview.components.preprocessors import DocumentCleaner, DocumentSplitter -from haystack.preview.components.routers import FileTypeRouter, DocumentJoiner -from haystack.preview.components.writers import DocumentWriter -from haystack.preview.document_stores import InMemoryDocumentStore -from haystack.preview.components.retrievers import InMemoryEmbeddingRetriever - - -def test_dense_doc_search_pipeline(tmp_path, samples_path): - # Create the indexing pipeline - indexing_pipeline = Pipeline() - indexing_pipeline.add_component( - instance=FileTypeRouter(mime_types=["text/plain", "application/pdf"]), name="file_type_router" - ) - indexing_pipeline.add_component(instance=TextFileToDocument(), name="text_file_converter") - indexing_pipeline.add_component(instance=PyPDFToDocument(), name="pdf_file_converter") - indexing_pipeline.add_component(instance=DocumentJoiner(), name="joiner") - indexing_pipeline.add_component(instance=DocumentCleaner(), 
name="cleaner") - indexing_pipeline.add_component( - instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30), name="splitter" - ) - indexing_pipeline.add_component( - instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="embedder", - ) - indexing_pipeline.add_component(instance=DocumentWriter(document_store=InMemoryDocumentStore()), name="writer") - - indexing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources") - indexing_pipeline.connect("file_type_router.application/pdf", "pdf_file_converter.sources") - indexing_pipeline.connect("text_file_converter.documents", "joiner.documents") - indexing_pipeline.connect("pdf_file_converter.documents", "joiner.documents") - indexing_pipeline.connect("joiner.documents", "cleaner.documents") - indexing_pipeline.connect("cleaner.documents", "splitter.documents") - indexing_pipeline.connect("splitter.documents", "embedder.documents") - indexing_pipeline.connect("embedder.documents", "writer.documents") - - # Draw the indexing pipeline - indexing_pipeline.draw(tmp_path / "test_dense_doc_search_indexing_pipeline.png") - - # Serialize the indexing pipeline to JSON - with open(tmp_path / "test_dense_doc_search_indexing_pipeline.json", "w") as f: - print(json.dumps(indexing_pipeline.to_dict(), indent=4)) - json.dump(indexing_pipeline.to_dict(), f) - - # Load the indexing pipeline back - with open(tmp_path / "test_dense_doc_search_indexing_pipeline.json", "r") as f: - indexing_pipeline = Pipeline.from_dict(json.load(f)) - - indexing_result = indexing_pipeline.run({"file_type_router": {"sources": samples_path.iterdir()}}) - filled_document_store = indexing_pipeline.get_component("writer").document_store - - assert indexing_result["writer"]["documents_written"] == 2 - assert filled_document_store.count_documents() == 2 - - # Create the querying pipeline - query_pipeline = Pipeline() - query_pipeline.add_component( - instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="text_embedder", - ) - query_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=filled_document_store, top_k=20), name="embedding_retriever" - ) - query_pipeline.connect("text_embedder", "embedding_retriever") - - querying_result = query_pipeline.run({"text_embedder": {"text": "Who lives in Rome?"}}) - assert querying_result["embedding_retriever"]["documents"][0].content == "My name is Giorgio and I live in Rome." 
- - # Draw the querying pipeline - query_pipeline.draw(tmp_path / "test_dense_doc_search_query_pipeline.png") - - # Serialize the querying pipeline to JSON - with open(tmp_path / "test_dense_doc_search_query_pipeline.json", "w") as f: - print(json.dumps(query_pipeline.to_dict(), indent=4)) - json.dump(query_pipeline.to_dict(), f) - - # Load the querying pipeline back - with open(tmp_path / "test_dense_doc_search_query_pipeline.json", "r") as f: - query_pipeline = Pipeline.from_dict(json.load(f)) diff --git a/e2e/preview/pipelines/test_extractive_qa_pipeline.py b/e2e/preview/pipelines/test_extractive_qa_pipeline.py deleted file mode 100644 index 5af133dd40..0000000000 --- a/e2e/preview/pipelines/test_extractive_qa_pipeline.py +++ /dev/null @@ -1,67 +0,0 @@ -import json - -from haystack.preview import Pipeline, Document -from haystack.preview.document_stores import InMemoryDocumentStore -from haystack.preview.components.retrievers import InMemoryBM25Retriever -from haystack.preview.components.readers import ExtractiveReader - - -def test_extractive_qa_pipeline(tmp_path): - # Create the pipeline - qa_pipeline = Pipeline() - qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") - qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader") - qa_pipeline.connect("retriever", "reader") - - # Draw the pipeline - qa_pipeline.draw(tmp_path / "test_extractive_qa_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f: - print(json.dumps(qa_pipeline.to_dict(), indent=4)) - json.dump(qa_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_bm25_rag_pipeline.json", "r") as f: - qa_pipeline = Pipeline.from_dict(json.load(f)) - - # Populate the document store - documents = [ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome."), - ] - qa_pipeline.get_component("retriever").document_store.write_documents(documents) - - # Query and assert - questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"] - answers_spywords = ["Jean", "Mark", "Giorgio"] - - for question, spyword, doc in zip(questions, answers_spywords, documents): - result = qa_pipeline.run({"retriever": {"query": question}, "reader": {"query": question}}) - - extracted_answers = result["reader"]["answers"] - - # we expect at least one real answer and no_answer - assert len(extracted_answers) > 1 - - # the best answer should contain the spyword - assert spyword in extracted_answers[0].data - - # no_answer - assert extracted_answers[-1].data is None - - # since these questions are easily answerable, the best answer should have higher probability than no_answer - assert extracted_answers[0].probability >= extracted_answers[-1].probability - - for answer in extracted_answers: - assert answer.query == question - - assert hasattr(answer, "probability") - assert hasattr(answer, "start") - assert hasattr(answer, "end") - - assert hasattr(answer, "document") - # the answer is extracted from the correct document - if answer.document is not None: - assert answer.document.id == doc.id diff --git a/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py b/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py deleted file mode 100644 index e85db341d8..0000000000 --- 
a/e2e/preview/pipelines/test_hybrid_doc_search_pipeline.py +++ /dev/null @@ -1,57 +0,0 @@ -import json - -from haystack.preview import Pipeline, Document -from haystack.preview.components.embedders import SentenceTransformersTextEmbedder -from haystack.preview.components.rankers import TransformersSimilarityRanker -from haystack.preview.components.routers.document_joiner import DocumentJoiner -from haystack.preview.document_stores import InMemoryDocumentStore -from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever - - -def test_hybrid_doc_search_pipeline(tmp_path): - # Create the pipeline - document_store = InMemoryDocumentStore() - hybrid_pipeline = Pipeline() - hybrid_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever") - hybrid_pipeline.add_component( - instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="text_embedder", - ) - hybrid_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever" - ) - hybrid_pipeline.add_component(instance=DocumentJoiner(), name="joiner") - hybrid_pipeline.add_component(instance=TransformersSimilarityRanker(top_k=20), name="ranker") - - hybrid_pipeline.connect("bm25_retriever", "joiner") - hybrid_pipeline.connect("text_embedder", "embedding_retriever") - hybrid_pipeline.connect("embedding_retriever", "joiner") - hybrid_pipeline.connect("joiner", "ranker") - - # Draw the pipeline - hybrid_pipeline.draw(tmp_path / "test_hybrid_doc_search_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_hybrid_doc_search_pipeline.json", "w") as f: - print(json.dumps(hybrid_pipeline.to_dict(), indent=4)) - json.dump(hybrid_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_hybrid_doc_search_pipeline.json", "r") as f: - hybrid_pipeline = Pipeline.from_dict(json.load(f)) - - # Populate the document store - documents = [ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Mario and I live in the capital of Italy."), - Document(content="My name is Giorgio and I live in Rome."), - ] - hybrid_pipeline.get_component("bm25_retriever").document_store.write_documents(documents) - - query = "Who lives in Rome?" - result = hybrid_pipeline.run( - {"bm25_retriever": {"query": query}, "text_embedder": {"text": query}, "ranker": {"query": query}} - ) - assert result["ranker"]["documents"][0].content == "My name is Giorgio and I live in Rome." - assert result["ranker"]["documents"][1].content == "My name is Mario and I live in the capital of Italy." 
diff --git a/e2e/preview/pipelines/test_preprocessing_pipeline.py b/e2e/preview/pipelines/test_preprocessing_pipeline.py deleted file mode 100644 index 2f16f1d993..0000000000 --- a/e2e/preview/pipelines/test_preprocessing_pipeline.py +++ /dev/null @@ -1,89 +0,0 @@ -import json - -from haystack.preview import Pipeline -from haystack.preview.components.embedders import SentenceTransformersDocumentEmbedder -from haystack.preview.components.converters import TextFileToDocument -from haystack.preview.components.preprocessors import DocumentSplitter, DocumentCleaner -from haystack.preview.components.classifiers import DocumentLanguageClassifier -from haystack.preview.components.routers import FileTypeRouter, MetadataRouter -from haystack.preview.components.writers import DocumentWriter -from haystack.preview.document_stores import InMemoryDocumentStore - - -def test_preprocessing_pipeline(tmp_path): - # Create the pipeline and its components - document_store = InMemoryDocumentStore() - preprocessing_pipeline = Pipeline() - preprocessing_pipeline.add_component(instance=FileTypeRouter(mime_types=["text/plain"]), name="file_type_router") - preprocessing_pipeline.add_component(instance=TextFileToDocument(), name="text_file_converter") - preprocessing_pipeline.add_component(instance=DocumentLanguageClassifier(), name="language_classifier") - preprocessing_pipeline.add_component( - instance=MetadataRouter(rules={"en": {"field": "language", "operator": "==", "value": "en"}}), name="router" - ) - preprocessing_pipeline.add_component(instance=DocumentCleaner(), name="cleaner") - preprocessing_pipeline.add_component( - instance=DocumentSplitter(split_by="sentence", split_length=1), name="splitter" - ) - preprocessing_pipeline.add_component( - instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="embedder", - ) - preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer") - preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.sources") - preprocessing_pipeline.connect("text_file_converter.documents", "language_classifier.documents") - preprocessing_pipeline.connect("language_classifier.documents", "router.documents") - preprocessing_pipeline.connect("router.en", "cleaner.documents") - preprocessing_pipeline.connect("cleaner.documents", "splitter.documents") - preprocessing_pipeline.connect("splitter.documents", "embedder.documents") - preprocessing_pipeline.connect("embedder.documents", "writer.documents") - - # Draw the pipeline - preprocessing_pipeline.draw(tmp_path / "test_preprocessing_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_preprocessing_pipeline.json", "w") as f: - print(json.dumps(preprocessing_pipeline.to_dict(), indent=4)) - json.dump(preprocessing_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_preprocessing_pipeline.json", "r") as f: - preprocessing_pipeline = Pipeline.from_dict(json.load(f)) - - # Write a txt file - with open(tmp_path / "test_file_english.txt", "w") as f: - f.write( - "This is an english sentence. There is more to it. It's a long text." - "Spans multiple lines." - "" - "Even contains empty lines. And extra whitespaces." 
- ) - - # Write a txt file - with open(tmp_path / "test_file_german.txt", "w") as f: - f.write("Ein deutscher Satz ohne Verb.") - - # Add two txt files and one non-txt file - paths = [ - tmp_path / "test_file_english.txt", - tmp_path / "test_file_german.txt", - tmp_path / "test_preprocessing_pipeline.json", - ] - - result = preprocessing_pipeline.run({"file_type_router": {"sources": paths}}) - - assert result["writer"]["documents_written"] == 6 - filled_document_store = preprocessing_pipeline.get_component("writer").document_store - assert filled_document_store.count_documents() == 6 - - # Check preprocessed texts - stored_documents = filled_document_store.filter_documents() - expected_texts = [ - "This is an english sentence.", - " There is more to it.", - " It's a long text.", - "Spans multiple lines.", - "Even contains empty lines.", - " And extra whitespaces.", - ] - assert expected_texts == [document.content for document in stored_documents] - assert all(document.meta["language"] == "en" for document in stored_documents) diff --git a/e2e/preview/pipelines/test_rag_pipelines.py b/e2e/preview/pipelines/test_rag_pipelines.py deleted file mode 100644 index d05c0fabea..0000000000 --- a/e2e/preview/pipelines/test_rag_pipelines.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import json -import pytest - -from haystack.preview import Pipeline, Document -from haystack.preview.document_stores import InMemoryDocumentStore -from haystack.preview.components.writers import DocumentWriter -from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever -from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder -from haystack.preview.components.generators import GPTGenerator -from haystack.preview.components.builders.answer_builder import AnswerBuilder -from haystack.preview.components.builders.prompt_builder import PromptBuilder - - -@pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", -) -def test_bm25_rag_pipeline(tmp_path): - # Create the RAG pipeline - prompt_template = """ - Given these documents, answer the question.\nDocuments: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - - \nQuestion: {{question}} - \nAnswer: - """ - rag_pipeline = Pipeline() - rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") - rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") - rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") - rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") - rag_pipeline.connect("retriever", "prompt_builder.documents") - rag_pipeline.connect("prompt_builder", "llm") - rag_pipeline.connect("llm.replies", "answer_builder.replies") - rag_pipeline.connect("llm.metadata", "answer_builder.metadata") - rag_pipeline.connect("retriever", "answer_builder.documents") - - # Draw the pipeline - rag_pipeline.draw(tmp_path / "test_bm25_rag_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_bm25_rag_pipeline.json", "w") as f: - json.dump(rag_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_bm25_rag_pipeline.json", "r") as f: - rag_pipeline = Pipeline.from_dict(json.load(f)) - - # Populate the document store - documents = [ - 
Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome."), - ] - rag_pipeline.get_component("retriever").document_store.write_documents(documents) - - # Query and assert - questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"] - answers_spywords = ["Jean", "Mark", "Giorgio"] - - for question, spyword in zip(questions, answers_spywords): - result = rag_pipeline.run( - { - "retriever": {"query": question}, - "prompt_builder": {"question": question}, - "answer_builder": {"query": question}, - } - ) - - assert len(result["answer_builder"]["answers"]) == 1 - generated_answer = result["answer_builder"]["answers"][0] - assert spyword in generated_answer.data - assert generated_answer.query == question - assert hasattr(generated_answer, "documents") - assert hasattr(generated_answer, "metadata") - - -@pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", -) -def test_embedding_retrieval_rag_pipeline(tmp_path): - # Create the RAG pipeline - prompt_template = """ - Given these documents, answer the question.\nDocuments: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - - \nQuestion: {{question}} - \nAnswer: - """ - rag_pipeline = Pipeline() - rag_pipeline.add_component( - instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="text_embedder", - ) - rag_pipeline.add_component( - instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever" - ) - rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") - rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") - rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") - rag_pipeline.connect("text_embedder", "retriever") - rag_pipeline.connect("retriever", "prompt_builder.documents") - rag_pipeline.connect("prompt_builder", "llm") - rag_pipeline.connect("llm.replies", "answer_builder.replies") - rag_pipeline.connect("llm.metadata", "answer_builder.metadata") - rag_pipeline.connect("retriever", "answer_builder.documents") - - # Draw the pipeline - rag_pipeline.draw(tmp_path / "test_embedding_rag_pipeline.png") - - # Serialize the pipeline to JSON - with open(tmp_path / "test_embedding_rag_pipeline.json", "w") as f: - json.dump(rag_pipeline.to_dict(), f) - - # Load the pipeline back - with open(tmp_path / "test_embedding_rag_pipeline.json", "r") as f: - rag_pipeline = Pipeline.from_dict(json.load(f)) - - # Populate the document store - documents = [ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome."), - ] - document_store = rag_pipeline.get_component("retriever").document_store - indexing_pipeline = Pipeline() - indexing_pipeline.add_component( - instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="document_embedder", - ) - indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="document_writer") - indexing_pipeline.connect("document_embedder", "document_writer") - indexing_pipeline.run({"document_embedder": {"documents": documents}}) - - # Query and assert - 
questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"] - answers_spywords = ["Jean", "Mark", "Giorgio"] - - for question, spyword in zip(questions, answers_spywords): - result = rag_pipeline.run( - { - "text_embedder": {"text": question}, - "prompt_builder": {"question": question}, - "answer_builder": {"query": question}, - } - ) - - assert len(result["answer_builder"]["answers"]) == 1 - generated_answer = result["answer_builder"]["answers"][0] - assert spyword in generated_answer.data - assert generated_answer.query == question - assert hasattr(generated_answer, "documents") - assert hasattr(generated_answer, "metadata") diff --git a/e2e/preview/samples/doc_1.txt b/e2e/preview/samples/doc_1.txt deleted file mode 100644 index 1d3da15eb9..0000000000 --- a/e2e/preview/samples/doc_1.txt +++ /dev/null @@ -1 +0,0 @@ -My name is Giorgio and I live in Rome. diff --git a/e2e/preview/samples/sample_pdf_1.pdf b/e2e/preview/samples/sample_pdf_1.pdf deleted file mode 100644 index 87259b897f..0000000000 Binary files a/e2e/preview/samples/sample_pdf_1.pdf and /dev/null differ diff --git a/examples/basic_faq_pipeline.py b/examples/basic_faq_pipeline.py deleted file mode 100644 index e198ca5367..0000000000 --- a/examples/basic_faq_pipeline.py +++ /dev/null @@ -1,76 +0,0 @@ -# Disable pylint errors for logging basicConfig -# pylint: disable=no-logging-basicconfig -import logging - -import pandas as pd - -from haystack.document_stores import ElasticsearchDocumentStore -from haystack.nodes import EmbeddingRetriever -from haystack.nodes.other.docs2answers import Docs2Answers -from haystack.pipelines import Pipeline -from haystack.utils import fetch_archive_from_http, launch_es, print_answers - -logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING) -logging.getLogger("haystack").setLevel(logging.INFO) - - -def basic_faq_pipeline(): - document_store = ElasticsearchDocumentStore( - host="localhost", - username="", - password="", - index="example-document", - embedding_field="question_emb", - embedding_dim=384, - excluded_meta_data=["question_emb"], - similarity="cosine", - ) - - retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-MiniLM-L6-v2", - use_gpu=True, - scale_score=False, - ) - - doc_to_answers = Docs2Answers() - - doc_dir = "data/basic_faq_pipeline" - s3_url = "https://core-engineering.s3.eu-central-1.amazonaws.com/public/scripts/small_faq_covid.csv1.zip" - fetch_archive_from_http(url=s3_url, output_dir=doc_dir) - - df = pd.read_csv(f"{doc_dir}/small_faq_covid.csv") - - # Minimal cleaning - df.fillna(value="", inplace=True) - df["question"] = df["question"].apply(lambda x: x.strip()) - print(df.head()) - - # Get embeddings for our questions from the FAQs - questions = list(df["question"].values) - df["question_emb"] = retriever.embed_queries(queries=questions).tolist() - df = df.rename(columns={"question": "content"}) - - # Convert Dataframe to list of dicts and index them in our DocumentStore - docs_to_index = df.to_dict(orient="records") - document_store.write_documents(docs_to_index) - document_store.update_embeddings(retriever) - - # Initialize a Pipeline (this time without a reader) and ask questions - pipeline = Pipeline() - pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"]) - pipeline.add_node(component=doc_to_answers, name="Docs2Answers", inputs=["Retriever"]) - - # Ask a question - prediction = pipeline.run(query="How is the 
virus spreading?", params={"Retriever": {"top_k": 10}}) - - print_answers(prediction, details="medium") - - # Remove the index once we're done to save space - document_store.delete_index(index="example-document") - return prediction - - -if __name__ == "__main__": - launch_es() - basic_faq_pipeline() diff --git a/examples/basic_qa_pipeline.py b/examples/basic_qa_pipeline.py deleted file mode 100644 index 97988627ee..0000000000 --- a/examples/basic_qa_pipeline.py +++ /dev/null @@ -1,79 +0,0 @@ -# Disable pylint errors for logging basicConfig -# pylint: disable=no-logging-basicconfig -import logging -from pathlib import Path - -from haystack.document_stores import ElasticsearchDocumentStore -from haystack.nodes import BM25Retriever, FARMReader -from haystack.nodes.file_classifier import FileTypeClassifier -from haystack.nodes.file_converter import TextConverter -from haystack.nodes.preprocessor import PreProcessor -from haystack.pipelines import Pipeline -from haystack.utils import fetch_archive_from_http, launch_es, print_answers - -logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING) -logging.getLogger("haystack").setLevel(logging.INFO) - - -def basic_qa_pipeline(): - # Initialize a DocumentStore - document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="example-document") - - # fetch, pre-process and write documents - doc_dir = "data/basic_qa_pipeline" - s3_url = "https://core-engineering.s3.eu-central-1.amazonaws.com/public/scripts/wiki_gameofthrones_txt1.zip" - fetch_archive_from_http(url=s3_url, output_dir=doc_dir) - - file_paths = [p for p in Path(doc_dir).glob("**/*")] - files_metadata = [{"name": path.name} for path in file_paths] - - # Indexing Pipeline - indexing_pipeline = Pipeline() - - # Makes sure the file is a TXT file (FileTypeClassifier node) - classifier = FileTypeClassifier() - indexing_pipeline.add_node(classifier, name="Classifier", inputs=["File"]) - - # Converts a file into text and performs basic cleaning (TextConverter node) - text_converter = TextConverter(remove_numeric_tables=True) - indexing_pipeline.add_node(text_converter, name="Text_converter", inputs=["Classifier.output_1"]) - - # - Pre-processes the text by performing splits and adding metadata to the text (Preprocessor node) - preprocessor = PreProcessor( - clean_whitespace=True, - clean_empty_lines=True, - split_length=100, - split_overlap=50, - split_respect_sentence_boundary=True, - ) - indexing_pipeline.add_node(preprocessor, name="Preprocessor", inputs=["Text_converter"]) - - # - Writes the resulting documents into the document store - indexing_pipeline.add_node(document_store, name="Document_Store", inputs=["Preprocessor"]) - - # Then we run it with the documents and their metadata as input - indexing_pipeline.run(file_paths=file_paths, meta=files_metadata) - - # Initialize Retriever & Reader - retriever = BM25Retriever(document_store=document_store) - reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True) - - # Query Pipeline - pipeline = Pipeline() - pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"]) - pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"]) - - prediction = pipeline.run( - query="Who is the father of Arya Stark?", params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}} - ) - - print_answers(prediction, details="minimum") - - # Remove the index once we're done to save space - 
document_store.delete_index(index="example-document") - return prediction - - -if __name__ == "__main__": - launch_es() - basic_qa_pipeline() diff --git a/examples/getting_started.py b/examples/getting_started.py deleted file mode 100644 index af0d5012fa..0000000000 --- a/examples/getting_started.py +++ /dev/null @@ -1,34 +0,0 @@ -from haystack.document_stores import InMemoryDocumentStore -from haystack.utils import build_pipeline, add_example_data, print_answers - - -def getting_started(provider, API_KEY): - """ - This getting_started example shows you how to use LLMs with your data with a technique called Retrieval Augmented Generation - RAG. - - :param provider: We are model agnostic :) Here, you can choose from: "anthropic", "cohere", "huggingface", and "openai". - :param API_KEY: The API key matching the provider. - - """ - - # We support many different databases. Here we load a simple and lightweight in-memory database. - document_store = InMemoryDocumentStore(use_bm25=True) - - # Pipelines are the main abstraction in Haystack, they connect components like LLMs and databases. - pipeline = build_pipeline(provider, API_KEY, document_store) - - # Download and add Game of Thrones TXT articles to Haystack's database. - # You can also provide a folder with your local documents. - # You might need to install additional dependencies - look inside the function for more information. - add_example_data(document_store, "data/GoT_getting_started") - - # Ask a question on the data you just added. - result = pipeline.run(query="Who is the father of Arya Stark?") - - # For details such as which documents were used to generate the answer, look into the object. - print_answers(result, details="medium") - return result - - -if __name__ == "__main__": - getting_started(provider="openai", API_KEY="ADD KEY HERE") diff --git a/examples/hybrid_search_faq_pipeline.py b/examples/hybrid_search_faq_pipeline.py deleted file mode 100644 index d4fcba6cf0..0000000000 --- a/examples/hybrid_search_faq_pipeline.py +++ /dev/null @@ -1,85 +0,0 @@ -# import logging - -import pandas as pd - -from haystack.document_stores import ElasticsearchDocumentStore -from haystack.nodes import EmbeddingRetriever, BM25Retriever, JoinDocuments, SentenceTransformersRanker -from haystack.nodes.other.docs2answers import Docs2Answers -from haystack.utils import launch_es, print_answers, fetch_archive_from_http -from haystack.pipelines import Pipeline - -# logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING) -# logging.getLogger("haystack").setLevel(logging.INFO) - - -def hybrid_search_faq_pipeline(): - document_store = ElasticsearchDocumentStore( - host="localhost", - username="", - password="", - index="document", - embedding_field="question_emb", - embedding_dim=384, - excluded_meta_data=["question_emb"], - similarity="cosine", - ) - - sparse_retriever = BM25Retriever(document_store=document_store) - dense_retriever = EmbeddingRetriever( - document_store=document_store, - embedding_model="sentence-transformers/all-MiniLM-L6-v2", - use_gpu=True, - scale_score=False, - ) - join_documents = JoinDocuments(join_mode="reciprocal_rank_fusion") - rerank = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") - - doc_to_answers = Docs2Answers() - - doc_dir = "data/basic_faq_pipeline" - s3_url = "https://core-engineering.s3.eu-central-1.amazonaws.com/public/scripts/small_faq_covid.csv1.zip" - fetch_archive_from_http(url=s3_url, output_dir=doc_dir) - - df = 
pd.read_csv(f"{doc_dir}/small_faq_covid.csv") - - # Minimal cleaning - df.fillna(value="", inplace=True) - df["question"] = df["question"].apply(lambda x: x.strip()) - print(df.head()) - - # Get embeddings for our questions from the FAQs - questions = list(df["question"].values) - df["question_emb"] = dense_retriever.embed_queries(queries=questions).tolist() - df = df.rename(columns={"question": "content"}) - - # Convert Dataframe to list of dicts and index them in our DocumentStore - docs_to_index = df.to_dict(orient="records") - document_store.write_documents(docs_to_index) - document_store.update_embeddings(retriever=dense_retriever) - - # Initialize a Pipeline (this time without a reader) and ask questions - pipeline = Pipeline() - pipeline.add_node(component=sparse_retriever, name="SparseRetriever", inputs=["Query"]) - pipeline.add_node(component=dense_retriever, name="DenseRetriever", inputs=["Query"]) - pipeline.add_node(component=join_documents, name="JoinDocuments", inputs=["SparseRetriever", "DenseRetriever"]) - pipeline.add_node(component=rerank, name="ReRanker", inputs=["JoinDocuments"]) - pipeline.add_node(component=doc_to_answers, name="Docs2Answers", inputs=["ReRanker"]) - - # Ask a question - prediction = pipeline.run( - query="How is the virus spreading?", - params={ - "SparseRetriever": {"top_k": 10}, - "DenseRetriever": {"top_k": 10}, - "JoinDocuments": {"top_k_join": 15}, - "ReRanker": {"top_k": 5}, - }, - ) - - print_answers(prediction, details="medium") - return prediction - - -if __name__ == "__main__": - launch_es() - hybrid_search_faq_pipeline() diff --git a/examples/preview/retrievers/in_memory_bm25_documentsearch.py b/examples/preview/retrievers/in_memory_bm25_documentsearch.py deleted file mode 100644 index e153bbefa6..0000000000 --- a/examples/preview/retrievers/in_memory_bm25_documentsearch.py +++ /dev/null @@ -1,28 +0,0 @@ -from haystack.preview import Document -from haystack.preview.components.retrievers import InMemoryBM25Retriever -from haystack.preview.document_stores import InMemoryDocumentStore -from haystack.preview.pipeline import Pipeline - -# Create components and a query pipeline -document_store = InMemoryDocumentStore() -retriever = InMemoryBM25Retriever(document_store=document_store) - -pipeline = Pipeline() -pipeline.add_component(instance=retriever, name="retriever") - -# Add Documents -documents = [ - Document(content="There are over 7,000 languages spoken around the world today."), - Document( - content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors." - ), - Document( - content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves." 
- ), -] -document_store.write_documents(documents) - -# Run the pipeline -result = pipeline.run(data={"retriever": {"query": "How many languages are there?"}}) - -print(result["retriever"]["documents"][0]) diff --git a/examples/preview/retrievers/in_memory_bm25_rag.py b/examples/preview/retrievers/in_memory_bm25_rag.py deleted file mode 100644 index ebb9ec5b00..0000000000 --- a/examples/preview/retrievers/in_memory_bm25_rag.py +++ /dev/null @@ -1,53 +0,0 @@ -import os - -from haystack.preview import Document -from haystack.preview import Pipeline -from haystack.preview.components.builders.answer_builder import AnswerBuilder -from haystack.preview.components.builders.prompt_builder import PromptBuilder -from haystack.preview.components.generators import GPTGenerator -from haystack.preview.components.retrievers import InMemoryBM25Retriever -from haystack.preview.document_stores import InMemoryDocumentStore - -# Create a RAG query pipeline -prompt_template = """ - Given these documents, answer the question.\nDocuments: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - - \nQuestion: {{question}} - \nAnswer: - """ - -rag_pipeline = Pipeline() -rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") -rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") -rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") -rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") -rag_pipeline.connect("retriever", "prompt_builder.documents") -rag_pipeline.connect("prompt_builder", "llm") -rag_pipeline.connect("llm.replies", "answer_builder.replies") -rag_pipeline.connect("llm.metadata", "answer_builder.metadata") -rag_pipeline.connect("retriever", "answer_builder.documents") - -# Draw the pipeline -rag_pipeline.draw("./rag_pipeline.png") - -# Add Documents -documents = [ - Document(content="There are over 7,000 languages spoken around the world today."), - Document( - content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors." - ), - Document( - content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves." - ), -] -rag_pipeline.get_component("retriever").document_store.write_documents(documents) - -# Run the pipeline -question = "How many languages are there?" -result = rag_pipeline.run( - {"retriever": {"query": question}, "prompt_builder": {"question": question}, "answer_builder": {"query": question}} -) -print(result["answer_builder"]["answers"][0]) diff --git a/examples/test_basic_faq_pipeline.py b/examples/test_basic_faq_pipeline.py deleted file mode 100644 index b637ad7223..0000000000 --- a/examples/test_basic_faq_pipeline.py +++ /dev/null @@ -1,19 +0,0 @@ -from examples.basic_faq_pipeline import basic_faq_pipeline - -from haystack.schema import Answer - - -def test_basic_faq_pipeline(): - prediction = basic_faq_pipeline() - - assert prediction is not None - assert prediction["query"] == "How is the virus spreading?" - - assert len(prediction["answers"]) == 10 # top-k of Retriever - assert type(prediction["answers"][0]) == Answer - assert ( - prediction["answers"][0].answer - == """This virus was first detected in Wuhan City, Hubei Province, China. 
The first infections were linked to a live animal market, but the virus is now spreading from person-to-person. It’s important to note that person-to-person spread can happen on a continuum. Some viruses are highly contagious (like measles), while other viruses are less so.\n\nThe virus that causes COVID-19 seems to be spreading easily and sustainably in the community (“community spread”) in some affected geographic areas. Community spread means people have been infected with the virus in an area, including some who are not sure how or where they became infected.\n\nLearn what is known about the spread of newly emerged coronaviruses.""" - ) - assert prediction["answers"][0].score <= 1 - assert prediction["answers"][0].score >= 0 diff --git a/examples/test_basic_qa_pipeline.py b/examples/test_basic_qa_pipeline.py deleted file mode 100644 index d538979822..0000000000 --- a/examples/test_basic_qa_pipeline.py +++ /dev/null @@ -1,23 +0,0 @@ -from examples.basic_qa_pipeline import basic_qa_pipeline - -from haystack.schema import Answer, Document - - -def test_basic_qa_pipeline(): - prediction = basic_qa_pipeline() - - assert prediction is not None - assert prediction["query"] == "Who is the father of Arya Stark?" - - assert len(prediction["answers"]) == 5 # top-k of Reader - assert type(prediction["answers"][0]) == Answer - assert prediction["answers"][0].answer == "Ned" - assert prediction["answers"][0].score <= 1 - assert prediction["answers"][0].score >= 0 - assert prediction["answers"][0].meta["name"] == "43_Arya_Stark.txt" - - assert len(prediction["documents"]) == 10 # top-k of Retriever - assert type(prediction["documents"][0]) == Document - assert prediction["documents"][0].score <= 1 - assert prediction["documents"][0].score >= 0 - assert prediction["documents"][0].meta["name"] == "450_Baelor.txt" diff --git a/examples/test_getting_started.py b/examples/test_getting_started.py deleted file mode 100644 index ee4b99aa98..0000000000 --- a/examples/test_getting_started.py +++ /dev/null @@ -1,26 +0,0 @@ -import os - -import pytest - -from examples.getting_started import getting_started -from haystack.schema import Answer, Document - - -@pytest.mark.parametrize("provider", ["cohere", "huggingface", "openai"]) -def test_getting_started(provider): - if provider == "anthropic": - api_key = os.environ.get("ANTHROPIC_API_KEY", "") - elif provider == "cohere": - api_key = os.environ.get("COHERE_API_KEY", "") - elif provider == "huggingface": - api_key = os.environ.get("HUGGINGFACE_API_KEY", "") - elif provider == "openai": - api_key = os.environ.get("OPENAI_API_KEY", "") - - if api_key: - result = getting_started(provider=provider, API_KEY=api_key) - - # Testing only for functionality. Since model predictions from APIs might change, we cannot test those directly. 
- assert isinstance(result, dict) - assert type(result["answers"][0]) == Answer - assert type(result["documents"][0]) == Document diff --git a/examples/web_lfqa.py b/examples/web_lfqa.py index ff5dbe15e0..cfdf81c602 100644 --- a/examples/web_lfqa.py +++ b/examples/web_lfqa.py @@ -21,7 +21,7 @@ """ prompt_node = PromptNode( - "text-davinci-003", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256 + "gpt-3.5-turbo-instruct", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256 ) web_retriever = WebRetriever(api_key=search_key, top_search_results=5, mode="preprocessed_documents", top_k=30) diff --git a/examples/web_qa.py b/examples/web_qa.py index 352d2d226d..adc7f19dba 100644 --- a/examples/web_qa.py +++ b/examples/web_qa.py @@ -12,7 +12,7 @@ raise ValueError("Please set the OPENAI_API_KEY environment variable") prompt_node = PromptNode( - "text-davinci-003", + "gpt-3.5-turbo-instruct", api_key=openai_key, max_length=256, default_prompt_template="question-answering-with-document-scores", diff --git a/haystack/agents/base.py b/haystack/agents/base.py index da6d1d61ba..f8b2504128 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -19,7 +19,6 @@ BaseStandardPipeline, ExtractiveQAPipeline, DocumentSearchPipeline, - GenerativeQAPipeline, SearchSummarizationPipeline, FAQPipeline, TranslationWrapperPipeline, @@ -57,7 +56,6 @@ def __init__( Pipeline, ExtractiveQAPipeline, DocumentSearchPipeline, - GenerativeQAPipeline, SearchSummarizationPipeline, FAQPipeline, TranslationWrapperPipeline, diff --git a/haystack/document_stores/__init__.py b/haystack/document_stores/__init__.py index 208d86e5d0..6dbe812d54 100644 --- a/haystack/document_stores/__init__.py +++ b/haystack/document_stores/__init__.py @@ -13,3 +13,4 @@ from haystack.document_stores.faiss import FAISSDocumentStore from haystack.document_stores.pinecone import PineconeDocumentStore from haystack.document_stores.weaviate import WeaviateDocumentStore +from haystack.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 258b874553..a666439d07 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -324,6 +324,18 @@ def write_documents( def _create_document_field_map(self) -> Dict: return {self.index: self.embedding_field} + def _validate_embedding_dimension(self, retriever: DenseRetriever, index: Optional[str] = None): + """ + Verify if the embedding dimension set in the document store and embedding dimension of the retriever are the same. + This check is done before calculating embeddings for all documents. + :param retriever: Retriever to use to get embeddings for text + :param index: Index name for which embeddings are to be updated. If set to None, the default self.index is used. 
+ :return: None + """ + first_document = self.get_all_documents(index=index)[0] + embeddings = retriever.embed_documents([first_document]) + self._validate_embeddings_shape(embeddings=embeddings, num_documents=1, embedding_dim=self.embedding_dim) + def update_embeddings( self, retriever: DenseRetriever, @@ -373,6 +385,8 @@ def update_embeddings( logger.warning("Calling DocumentStore.update_embeddings() on an empty index") return + self._validate_embedding_dimension(retriever, index) + logger.info("Updating embeddings for %s docs...", document_count) vector_id = self.faiss_indexes[index].ntotal diff --git a/haystack/document_stores/mongodb_atlas.py b/haystack/document_stores/mongodb_atlas.py new file mode 100644 index 0000000000..e7acd13e6a --- /dev/null +++ b/haystack/document_stores/mongodb_atlas.py @@ -0,0 +1,595 @@ +import re +from typing import Dict, Generator, List, Optional, Union + +import numpy as np +from tqdm import tqdm + +from haystack import __version__ as haystack_version +from haystack.document_stores import BaseDocumentStore +from haystack.errors import DocumentStoreError +from haystack.nodes.retriever import DenseRetriever +from haystack.schema import Document, FilterType +from haystack.utils import get_batches_from_generator + +from ..lazy_imports import LazyImport +from .mongodb_filters import mongo_filter_converter + +with LazyImport("Run 'pip install farm-haystack[mongodb]'") as mongodb_import: + import pymongo + from pymongo import InsertOne, ReplaceOne, UpdateOne + from pymongo.collection import Collection as MongoCollection + from pymongo.driver_info import DriverInfo + +METRIC_TYPES = ["euclidean", "cosine", "dotProduct"] +DEFAULT_BATCH_SIZE = 50 + + +class MongoDBAtlasDocumentStore(BaseDocumentStore): + def __init__( + self, + mongo_connection_string: Optional[str] = None, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + vector_search_index: Optional[str] = None, + embedding_dim: int = 768, + return_embedding: bool = False, + similarity: str = "cosine", + embedding_field: str = "embedding", + progress_bar: bool = True, + duplicate_documents: str = "overwrite", + recreate_index: bool = False, + ): + """ + Document Store using MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/). + It is compatible with EmbeddingRetriever and filters. + + :param mongo_connection_string: MongoDB Atlas connection string in the format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". + :param database_name: Name of the database to use. + :param collection_name: Name of the collection to use. + :param vector_search_index: The name of the index to use for vector search. To use the search index it must have been created in the Atlas web UI before. None by default. + :param embedding_dim: Dimensionality of embeddings, 768 by default. + :param return_embedding: Whether to return document embeddings when returning documents. + :param similarity: The similarity function to use for the embeddings. One of "euclidean", "cosine" or "dotProduct". "cosine" is the default. + :param embedding_field: The name of the field in the document that contains the embedding. + :param progress_bar: Whether to show a progress bar when writing documents. + :param duplicate_documents: How to handle duplicate documents. One of "overwrite", "skip" or "fail". "overwrite" is the default. 
+ :param recreate_index: Whether to recreate the index when initializing the document store. + """ + mongodb_import.check() + super().__init__() + + self.mongo_connection_string = _validate_mongo_connection_string(mongo_connection_string) + self.database_name = _validate_database_name(database_name) + self.collection_name = _validate_collection_name(collection_name) + self.connection: pymongo.MongoClient = pymongo.MongoClient( + self.mongo_connection_string, driver=DriverInfo(name="Haystack", version=haystack_version) + ) + self.database = self.connection[self.database_name] + self.similarity = _validate_similarity(similarity) + self.duplicate_documents = duplicate_documents + self.embedding_field = embedding_field + self.progress_bar = progress_bar + self.embedding_dim = embedding_dim + self.index = collection_name + self.return_embedding = return_embedding + self.recreate_index = recreate_index + self.vector_search_index = vector_search_index + + if self.recreate_index: + self.delete_index() + + # Implicitly create the collection if it doesn't exist + if collection_name not in self.database.list_collection_names(): + self.database.create_collection(self.collection_name) + self._get_collection().create_index("id", unique=True) + + def _create_document_field_map(self) -> Dict: + return {self.embedding_field: "embedding"} + + def _get_collection(self, index=None) -> "MongoCollection": + """ + Returns the collection named by index or returns the collection specified when the + driver was initialized. + """ + _validate_index_name(index) + if index is not None: + return self.database[index] + else: + return self.database[self.collection_name] + + def delete_documents( + self, + index: Optional[str] = None, + ids: Optional[List[str]] = None, + filters: Optional[FilterType] = None, + headers: Optional[Dict[str, str]] = None, + ): + """ + Delete documents from the document store. + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param ids: Optional list of IDs to narrow down the documents to be deleted. + :param filters: optional filters (see get_all_documents for description). + If filters are provided along with a list of IDs, this method deletes the + intersection of the two query results (documents that match the filters and + have their ID in the list). + :param headers: MongoDBAtlasDocumentStore does not support headers. + :return None: + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + if (ids, filters) == (None, None): + mongo_filters = {} + elif (ids, filters) == (None, filters): + mongo_filters = mongo_filter_converter(filters) + elif (ids, filters) == (ids, None): + mongo_filters = {"id": {"$in": ids}} + elif (ids, filters) == (ids, filters): + mongo_filters = {"$and": [mongo_filter_converter(filters), {"id": {"$in": ids}}]} + + collection.delete_many(filter=mongo_filters) # pylint: disable=possibly-used-before-assignment + + def delete_index(self, index=None): + """ + Deletes the collection named by index or the collection specified when the + driver was initialized. 
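As a usage illustration for the new `MongoDBAtlasDocumentStore` (a minimal sketch; the connection string, database, collection and index names below are placeholders, not values from this patch):

```python
from haystack.document_stores import MongoDBAtlasDocumentStore

document_store = MongoDBAtlasDocumentStore(
    mongo_connection_string="mongodb+srv://<user>:<password>@<cluster-host>/?retryWrites=true&w=majority",
    database_name="haystack_db",
    collection_name="documents",
    vector_search_index="embedding_index",  # must already exist in the Atlas web UI
    embedding_dim=768,
    similarity="cosine",
    duplicate_documents="overwrite",
)
```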
+ """ + self._get_collection(index).drop() + + def delete_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def get_all_documents( + self, + index: Optional[str] = None, + filters: Optional[FilterType] = None, + return_embedding: Optional[bool] = False, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + ): + """ + Retrieves all documents in the index (collection). + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param filters: Optional filters to narrow down the documents that will be retrieved. + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical + operator (`"$and"`, `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `"$in"`, `"$gt"`, + `"$gte"`, `"$lt"`, `"$lte"`) or a metadata field name. + Logical operator keys take a dictionary of metadata field names and/or logical operators as + value. Metadata field names take a dictionary of comparison operators as value. Comparison + operator keys take a single value or (in case of `"$in"`) a list of values as value. + If no logical operator is provided, `"$and"` is used as default operation. If no comparison + operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used as default + operation. + __Example__: + + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + ``` + Note that filters will be acting on the contents of the meta field of the documents in the collection. + :param return_embedding: Optional flag to return the embedding of the document. + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + result = self.get_all_documents_generator( + index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size + ) + return list(result) + + def get_all_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def get_document_count( + self, + filters: Optional[FilterType] = None, + index: Optional[str] = None, + only_documents_without_embedding: bool = False, + headers: Optional[Dict[str, str]] = None, + ) -> int: + """ + Return the number of documents. + + :param filters: Optional filters (see get_all_documents for description). + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param only_documents_without_embedding: If set to `True`, only documents without embeddings are counted. + :param headers: MongoDBAtlasDocumentStore does not support headers. 
+ """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + if only_documents_without_embedding: + mongo_filter = {"$and": [mongo_filter_converter(filters), {"embedding": {"$eq": None}}]} + else: + mongo_filter = mongo_filter_converter(filters) + + return collection.count_documents(mongo_filter) + + def get_embedding_count(self, filters: Optional[FilterType] = None, index: Optional[str] = None) -> int: + """ + Return the number of documents with embeddings. + + :param filters: Optional filters (see get_all_documents for description). + """ + collection = self._get_collection(index) + + filters = filters or {} + + mongo_filters = {"$and": [mongo_filter_converter(filters), {"embedding": {"$ne": None}}]} + + return collection.count_documents(mongo_filters) + + def get_all_documents_generator( + self, + index: Optional[str] = None, + filters: Optional[FilterType] = None, + return_embedding: Optional[bool] = False, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + ) -> Generator[Document, None, None]: + """ + Retrieves all documents in the index (collection). Under-the-hood, documents are fetched in batches from the + document store and yielded as individual documents. This method can be used to iteratively process + a large number of documents without having to load all documents in memory. + + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param filters: optional filters (see get_all_documents for description). + :param return_embedding: Optional flag to return the embedding of the document. + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + mongo_filters = mongo_filter_converter(filters) + + if return_embedding is None: + return_embedding = self.return_embedding + + projection = {"embedding": False} if not return_embedding else {} + + collection = self._get_collection(index) + documents = collection.find(mongo_filters, batch_size=batch_size, projection=projection) + + for doc in documents: + yield mongo_doc_to_haystack_doc(doc) + + def get_documents_by_id( + self, + ids: List[str], + index: Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + headers: Optional[Dict[str, str]] = None, + return_embedding: Optional[bool] = None, + ) -> List[Document]: + """ + Retrieves all documents matching ids. + + :param ids: List of IDs to retrieve. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param batch_size: Number of documents to retrieve at a time. When working with large number of documents, + batching can help reduce memory footprint. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param return_embedding: Optional flag to return the embedding of the document. 
+ """ + mongo_filters = {"id": {"$in": ids}} + + result = self.get_all_documents_generator( + index=index, + filters=mongo_filters, # type: ignore [arg-type] + return_embedding=return_embedding, + batch_size=batch_size, + headers=headers, + ) + + return list(result) + + def get_document_by_id( + self, + id: str, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + return_embedding: Optional[bool] = None, + ) -> Document: + """ + Retrieves the document matching id. + + :param id: The ID of the document to retrieve + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param return_embedding: Optional flag to return the embedding of the document. + """ + documents = self.get_documents_by_id(ids=[id], index=index, headers=headers, return_embedding=return_embedding) + return documents[0] + + def get_label_count(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def query_by_embedding( + self, + query_emb: np.ndarray, + filters: Optional[FilterType] = None, + top_k: int = 10, + index: Optional[str] = None, + return_embedding: Optional[bool] = None, + headers: Optional[Dict[str, str]] = None, + scale_score: bool = True, + ) -> List[Document]: + """ + Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query + :param filters: optional filters (see get_all_documents for description). + :param top_k: How many documents to return. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param return_embedding: Whether to return document embedding. + :param headers: MongoDBAtlasDocumentStore does not support headers. + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + """ + if not self.vector_search_index: + raise ValueError( + "No vector_search_index is set for MongoDBAtlasDocumentStore. Create a vector_search_index in the Atlas web UI and specify it in the init parameters of MongoDBAtlasDocumentStore. 
https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index" + ) + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + if return_embedding is None: + return_embedding = self.return_embedding + + collection = self._get_collection(index) + + query_emb = query_emb.astype(np.float32) + + if self.similarity == "cosine": + self.normalize_embedding(query_emb) + + filters = filters or {} + + pipeline = [ + { + "$vectorSearch": { + "index": self.vector_search_index, + "queryVector": query_emb.tolist(), + "path": "embedding", + "numCandidates": 100, + "limit": top_k, + } + } + ] + if filters is not None: + pipeline.append({"$match": mongo_filter_converter(filters)}) + if not return_embedding: + pipeline.append({"$project": {"embedding": False}}) + pipeline.append({"$set": {"score": {"$meta": "vectorSearchScore"}}}) + documents = list(collection.aggregate(pipeline)) + + if scale_score: + for doc in documents: + doc["score"] = self.scale_to_unit_interval(doc["score"], self.similarity) + + documents = [mongo_doc_to_haystack_doc(doc) for doc in documents] + return documents + + def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None): + """ + Update the metadata dictionary of a document by specifying its string ID. + + :param id: ID of the Document to update. + :param meta: Dictionary of new metadata. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + """ + collection = self._get_collection(index) + collection.update_one({"id": id}, {"$set": {"meta": meta}}) + + def write_documents( + self, + documents: Union[List[dict], List[Document]], + index: Optional[str] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + duplicate_documents: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): + """ + Parameters: + + documents: List of `Dicts` or `Documents` + index (str): search index name - contain letters, numbers, hyphens, or underscores + :param duplicate_documents: handle duplicate documents based on parameter options. + Parameter options: + - `"overwrite"`: Update any existing documents with the same ID when adding documents. + - `"skip"`: Ignore the duplicate documents. + - `"fail"`: An error is raised if the document ID of the document being added already exists. + + "overwrite" is the default behaviour. 
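A short sketch of writing documents with the duplicate-handling options described above (document contents and metadata are placeholders):

```python
from haystack import Document

docs = [
    Document(content="Berlin is the capital of Germany.", meta={"type": "article"}),
    Document(content="The Colossus of Rhodes stood in the harbour of Rhodes.", meta={"type": "article"}),
]

# "overwrite" (default) replaces documents with the same ID, "skip" ignores them,
# and "fail" raises an error if an ID already exists in the collection.
document_store.write_documents(docs, duplicate_documents="overwrite")
```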
+ """ + if headers: + raise NotImplementedError("MongoDBAtlasDocumentStore does not support headers.") + + collection = self._get_collection(index) + + duplicate_documents = duplicate_documents or self.duplicate_documents + + field_map = self._create_document_field_map() + documents = [ + Document.from_dict(doc, field_map=field_map) if isinstance(doc, dict) else doc for doc in documents + ] + + mongo_documents = list(map(Document.to_dict, documents)) + + with tqdm( + total=len(mongo_documents), + disable=not self.progress_bar, + position=0, + unit=" docs", + desc="Writing Documents", + ) as progress_bar: + batches = get_batches_from_generator(mongo_documents, batch_size) + for batch in batches: + operations: List[Union[UpdateOne, InsertOne, ReplaceOne]] + if duplicate_documents == "skip": + operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in batch] + elif duplicate_documents == "fail": + operations = [InsertOne(doc) for doc in batch] + else: + operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in batch] + + collection.bulk_write(operations) + progress_bar.update(len(batch)) + + def write_labels(self): + raise NotImplementedError("MongoDBAtlasDocumentStore does not support labels (yet).") + + def update_embeddings( + self, + retriever: DenseRetriever, + index: Optional[str] = None, + update_existing_embeddings: bool = True, + filters: Optional[FilterType] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + ): + """ + Updates the embeddings in the document store using the encoding model specified in the retriever. + + This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the + retriever config). + + :param retriever: Retriever to use to get embeddings for text. + :param index: Optional collection name. If `None`, the DocumentStore's default collection will be used. + :param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to `False`, + only documents without embeddings are processed. This mode can be used for incremental updating of + embeddings, wherein, only newly indexed documents get processed. + :param filters: optional filters (see get_all_documents for description). + :param batch_size: Number of documents to process at a time. When working with large number of documents, + batching can help reduce memory footprint. 
" + """ + filters = filters or {} + document_count = self.get_document_count( + index=index, filters=filters, only_documents_without_embedding=not update_existing_embeddings + ) + + if not update_existing_embeddings: + filters = {"$and": [filters, {"embedding": {"$eq": None}}]} + + documents = self.get_all_documents_generator( + index=index, filters=filters, return_embedding=False, batch_size=batch_size + ) + + collection = self._get_collection(index) + + with tqdm( + total=document_count, disable=not self.progress_bar, unit=" docs", desc="Updating Embeddings" + ) as progress_bar: + batches = get_batches_from_generator(documents, batch_size) + for batch in batches: + embeddings = retriever.embed_documents(batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(batch), embedding_dim=self.embedding_dim + ) + if self.similarity == "cosine": + self.normalize_embedding(embeddings) + + mongo_documents = [haystack_doc_to_mongo_doc(doc) for doc in batch] + + for doc, embedding in zip(mongo_documents, embeddings.tolist()): + doc["embedding"] = embedding + + updates = [ReplaceOne({"id": doc["id"]}, doc) for doc in mongo_documents] + collection.bulk_write(updates) + progress_bar.update(len(batch)) + + +class MongoDBAtlasDocumentStoreError(DocumentStoreError): + """Exception for issues that occur in a MongoDBAtlas document store""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class ValidationError(Exception): + """Exception for validation errors""" + + pass + + +def _validate_mongo_connection_string(mongo_connection_string): + if not mongo_connection_string: + raise MongoDBAtlasDocumentStoreError( + "A `mongodb_connection_string` is required. This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button." + ) + return mongo_connection_string + + +def _validate_database_name(database_name): + # There doesn't seem to be much restriction on the name here? All sorts of special character are apparently allowed... + # Just check if it's there. + if not database_name: + raise ValidationError("A `database_name` is required.") + return database_name + + +def _validate_collection_name(collection_name): + # There doesn't seem to be much restriction on the name here? All sorts of special character are apparently allowed... + # Just check if it's there. + if not collection_name: + raise ValidationError("A `collection_name` is required.") + return collection_name + + +def _validate_similarity(similarity): + if similarity not in METRIC_TYPES: + raise ValueError( + "MongoDB Atlas currently supports dotProduct, cosine and euclidean metrics. Please set similarity to one of the above." + ) + return similarity + + +def _validate_index_name(index_name): + if index_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", index_name)): + raise ValueError( + f'Invalid index name: "{index_name}". Index name can only contain letters, numbers, hyphens, or underscores.' 
+ ) + return index_name + + +def mongo_doc_to_haystack_doc(mongo_doc) -> Document: + embedding = mongo_doc.get("embedding", None) + score = mongo_doc.get("score") + + return Document( + id=mongo_doc["id"], + content=mongo_doc["content"], + content_type=mongo_doc["content_type"], + meta=mongo_doc["meta"], + embedding=embedding, + score=score, + ) + + +def haystack_doc_to_mongo_doc(haystack_doc) -> Dict: + return { + "id": haystack_doc.id, + "content": haystack_doc.content, + "content_type": haystack_doc.content_type, + "meta": haystack_doc.meta, + } diff --git a/haystack/document_stores/mongodb_filters.py b/haystack/document_stores/mongodb_filters.py new file mode 100644 index 0000000000..691ffcc8fb --- /dev/null +++ b/haystack/document_stores/mongodb_filters.py @@ -0,0 +1,91 @@ +from typing import Union, Any, Dict + +FILTER_OPERATORS = ["$and", "$or", "$not", "$eq", "$in", "$gt", "$gte", "$lt", "$lte"] +EXCLUDE_FROM_METADATA_PREPEND = ["id", "embedding"] + +METADATA_FIELD_NAME = "meta" + + +def mongo_filter_converter(filter) -> Dict[str, Any]: + if not filter: + return {} + else: + filter = _target_filter_to_metadata(filter, METADATA_FIELD_NAME) + filter = _and_or_to_list(filter) + return filter + + +def _target_filter_to_metadata(filter, metadata_field_name) -> Union[Dict[str, Any], list]: + """ + Returns a new filter with any non-operator, non-excluded keys renamed so that the metadata + field name is prepended. Does not mutate input filter. + + Example: + + { + "$and": { + "url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", + "_split_id": 0 + } + } + + will be replaced with: + + { + "$and": { + "meta.url": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes", + "meta._split_id": 0 + } + } + + """ + if isinstance(filter, dict): + updated_dict = {} + for key, value in filter.items(): + if key not in FILTER_OPERATORS + EXCLUDE_FROM_METADATA_PREPEND: + key = f"{metadata_field_name}.{key}" + + if isinstance(value, (dict, list)): + updated_dict[key] = _target_filter_to_metadata(value, metadata_field_name) + else: + updated_dict[key] = value + return updated_dict + elif isinstance(filter, list): + return [_target_filter_to_metadata(item, metadata_field_name) for item in filter] + return filter + + +def _and_or_to_list(filter) -> Union[Dict[str, Any], list]: + """ + Returns a new filter replacing any dict values associated with "$and" or "$or" keys + with a list. Does not mutate input filter. 
+ + Example: + + { + "$and": { + "url": {"$eq": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}, + "_split_id": {"$eq": 0}, + }, + } + + will be replaced with: + + { + "$and": [ + {"url": {"$eq": "https://en.wikipedia.org/wiki/Colossus_of_Rhodes"}}, + {"_split_id": {"$eq": 0}}, + ] + } + """ + if isinstance(filter, dict): + updated_dict = filter.copy() + if "$and" in updated_dict and isinstance(filter["$and"], dict): + updated_dict["$and"] = [{key: value} for key, value in filter["$and"].items()] + if "$or" in updated_dict and isinstance(filter["$or"], dict): + updated_dict["$or"] = [{key: value} for key, value in filter["$or"].items()] + return {key: _and_or_to_list(value) for key, value in updated_dict.items()} + elif isinstance(filter, list): + return [_and_or_to_list(item) for item in filter] + else: + return filter diff --git a/haystack/document_stores/opensearch.py b/haystack/document_stores/opensearch.py index 0d6e776660..8194b5aa40 100644 --- a/haystack/document_stores/opensearch.py +++ b/haystack/document_stores/opensearch.py @@ -1521,6 +1521,7 @@ def get_metadata_values_by_key( filters: Optional[FilterType] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + batch_size: int = 10, ) -> List[dict]: """ Get values associated with a metadata key. The output is in the format: @@ -1558,10 +1559,16 @@ def get_metadata_values_by_key( self.index is used. :param headers: Custom HTTP headers to pass to the client (for example, {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out [Elasticsearch documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html) for more information. + :param batch_size: Maximum number of results for each request. + Limited to 10 values by default. You can increase this limit to decrease retrieval time. + To reduce the pressure on the cluster, you shouldn't set this higher than 1,000. 
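A sketch of the new `batch_size` option on `get_metadata_values_by_key()` (the same change is applied to the search-engine base class further below); the host, index and metadata key are placeholders.

```python
from haystack.document_stores import OpenSearchDocumentStore

document_store = OpenSearchDocumentStore(host="localhost", port=9200, index="document")

# Fetch up to 500 distinct values of a hypothetical "category" metadata field per
# aggregation request instead of the default 10; staying below ~1,000 keeps cluster load low.
values = document_store.get_metadata_values_by_key(key="category", batch_size=500)
```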
""" + index = index or self.index body: dict = { "size": 0, - "aggs": {"metadata_agg": {"composite": {"sources": [{key: {"terms": {"field": key}}}]}}}, + "aggs": { + "metadata_agg": {"composite": {"sources": [{key: {"terms": {"field": key}}}], "size": batch_size}} + }, } if query: body["query"] = { diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py index 33da45f1f2..20403b32ae 100644 --- a/haystack/document_stores/pinecone.py +++ b/haystack/document_stores/pinecone.py @@ -6,7 +6,7 @@ from datetime import datetime from functools import reduce from itertools import islice -from typing import Any, Dict, Generator, List, Literal, Optional, Set, Union +from typing import Any, Dict, Generator, List, Literal, Optional, Set, Tuple, Union import numpy as np from tqdm import tqdm @@ -33,7 +33,9 @@ IN_OPERATOR = "$in" EQ_OPERATOR = "$eq" -DEFAULT_BATCH_SIZE = 128 +DEFAULT_BATCH_SIZE = 64 +DEFAULT_POOL_THREADS = 1 +DEFAULT_DOCUMENT_CHUNK_SIZE = 1000 PINECONE_STARTER_POD = "starter" @@ -79,6 +81,8 @@ def __init__( environment: str = "us-west1-gcp", pinecone_index: Optional["pinecone.Index"] = None, embedding_dim: int = 768, + pods: int = 1, + pod_type: str = "p1.x1", return_embedding: bool = False, index: str = "document", similarity: str = "cosine", @@ -91,6 +95,7 @@ def __init__( recreate_index: bool = False, metadata_config: Optional[Dict] = None, validate_index_sync: bool = True, + pool_threads: int = DEFAULT_POOL_THREADS, ): """ :param api_key: Pinecone vector database API key ([https://app.pinecone.io](https://app.pinecone.io)). @@ -98,6 +103,8 @@ def __init__( regions are supported, contact Pinecone [here](https://www.pinecone.io/contact/) if required. :param pinecone_index: pinecone-client Index object, an index will be initialized or loaded if not specified. :param embedding_dim: The embedding vector size. + :param pods: The number of pods for the index to use, including replicas. Defaults to 1. + :param pod_type: The type of pod to use. Defaults to `"p1.x1"`. :param return_embedding: Whether to return document embeddings. :param index: Name of index in document store to use. :param similarity: The similarity function used to compare document vectors. `"cosine"` is the default @@ -126,6 +133,7 @@ def __init__( [selective metadata filtering](https://www.pinecone.io/docs/manage-indexes/#selective-metadata-indexing) feature. Should be in the format `{"indexed": ["metadata-field-1", "metadata-field-2", "metadata-field-n"]}`. By default, no fields are indexed. + :param pool_threads: Number of threads to use for index upsert. 
""" pinecone_import.check() if metadata_config is None: @@ -151,6 +159,8 @@ def __init__( self.duplicate_documents = duplicate_documents # Pinecone index params + self.pods = pods + self.pod_type = pod_type self.replicas = replicas self.shards = shards self.namespace = namespace @@ -168,8 +178,8 @@ def __init__( # Initialize dictionary to store temporary set of document IDs self.all_ids: dict = {} - # Dummy query to be used during searches - self.dummy_query = [0.0] * self.embedding_dim + # Dummy vector to be used during searches and as a placeholder for documents without embeddings + self.dummy_vector = [-10.0] * self.embedding_dim if pinecone_index: if not isinstance(pinecone_index, pinecone.Index): @@ -182,12 +192,15 @@ def __init__( else: self.pinecone_indexes[self.index] = self._create_index( embedding_dim=self.embedding_dim, + pods=self.pods, + pod_type=self.pod_type, index=self.index, metric_type=self.metric_type, replicas=self.replicas, shards=self.shards, recreate_index=recreate_index, metadata_config=self.metadata_config, + pool_threads=pool_threads, ) super().__init__() @@ -199,12 +212,15 @@ def _index(self, index) -> str: def _create_index( self, embedding_dim: int, + pods: int = 1, + pod_type: str = "p1.x1", index: Optional[str] = None, metric_type: Optional[str] = "cosine", replicas: Optional[int] = 1, shards: Optional[int] = 1, recreate_index: bool = False, metadata_config: Optional[Dict] = None, + pool_threads: int = DEFAULT_POOL_THREADS, ) -> "pinecone.Index": """ Create a new index for storing documents in case an index with the name @@ -225,12 +241,14 @@ def _create_index( pinecone.create_index( name=index, dimension=embedding_dim, + pods=pods, + pod_type=pod_type, metric=metric_type, replicas=replicas, shards=shards, metadata_config=metadata_config, ) - index_connection = pinecone.Index(index) + index_connection = pinecone.Index(index, pool_threads) # Get index statistics stats = index_connection.describe_index_stats() @@ -254,6 +272,8 @@ def _index_connection_exists(self, index: str, create: bool = False) -> Optional if create: return self._create_index( embedding_dim=self.embedding_dim, + pods=self.pods, + pod_type=self.pod_type, index=index, metric_type=self.metric_type, replicas=self.replicas, @@ -364,9 +384,9 @@ def _get_vector_count( return namespaces[namespace]["vector_count"] if namespace in namespaces else 0 # Due to missing support for metadata filtering in `describe_index_stats()` method for `gcp-starter`, - # use dummy query for getting vector count + # use dummy query vector for getting vector count res = self.pinecone_indexes[index].query( - self.dummy_query, + self.dummy_vector, top_k=self.top_k_limit, include_values=False, include_metadata=False, @@ -386,6 +406,23 @@ def _delete_vectors(self, index: str, ids: List[str], namespace: Optional[str]) for id_batch in get_batches_from_generator(ids, batch_size): self.pinecone_indexes[index].delete(ids=list(id_batch), namespace=namespace) + def _upsert_vectors( + self, + index_name: str, + data: List[Tuple], + namespace: Optional[str], + use_async: bool = False, + batch_size: int = DEFAULT_BATCH_SIZE, + ) -> None: + index = self.pinecone_indexes[index_name] + results = [ + index.upsert(vectors=batch, namespace=namespace, async_req=use_async) + for batch in get_batches_from_generator(data, batch_size) + ] + if use_async: + for res in results: + res.get() + def get_document_count( self, filters: Optional[FilterType] = None, @@ -517,6 +554,8 @@ def write_documents( headers: Optional[Dict[str, str]] = None, labels: 
Optional[bool] = False, namespace: Optional[str] = None, + use_async: bool = False, + document_chunk_size: int = DEFAULT_DOCUMENT_CHUNK_SIZE, ): """ Add new documents to the DocumentStore. @@ -524,7 +563,7 @@ def write_documents( :param documents: List of `Dicts` or list of `Documents`. If they already contain embeddings, we'll index them right away in Pinecone. If not, you can later call `update_embeddings()` to create & index them. :param index: Index name for storing the docs and metadata. - :param batch_size: Number of documents to process at a time. When working with large number of documents, + :param batch_size: Number of documents to upsert at a time. When working with large number of documents, batching can help to reduce the memory footprint. :param duplicate_documents: handle duplicate documents based on parameter options. Parameter options: @@ -534,6 +573,9 @@ def write_documents( :param headers: PineconeDocumentStore does not support headers. :param labels: Tells us whether these records are labels or not. Defaults to False. :param namespace: Optional namespace to write documents to. If not specified, None is default. + :param use_async: If set to True, Pinecone index will upsert documents in parallel. + :param document_chunk_size: Number of documents to process at a time. If use_async is set to True, + along with batch_size will speed up document upsert by doing it in parallel. :raises DuplicateDocumentError: Exception trigger on duplicate document. """ if headers: @@ -549,6 +591,20 @@ def write_documents( if index_connection: self.pinecone_indexes[index] = index_connection + pool_threads = self.pinecone_indexes[index].pool_threads + if use_async and pool_threads == 1: + logger.warning( + "Documents will be upserted synchronosly, because the number of threads for Pinecone index is set to %s. 
" + "To enable upsert in parallel, initialize PineconeDocumentStore() again setting parameter `pool_threads`.", + pool_threads, + ) + elif not use_async and pool_threads != 1: + logger.warning( + "Parameter `use_async` set to `False` will be ignored and documents will be upserted asynchronously, " + "because the number of threads for Pinecone index is set to %s.", + pool_threads, + ) + field_map = self._create_document_field_map() document_objects = [ Document.from_dict(doc, field_map=field_map) if isinstance(doc, dict) else doc for doc in documents @@ -556,6 +612,9 @@ def write_documents( document_objects = self._handle_duplicate_documents( documents=document_objects, index=index, duplicate_documents=duplicate_documents ) + + # set chunk size to document_chunk_size for async upsert or batch_size otherwise (regular upsert) + chunk_size = document_chunk_size if use_async else batch_size if document_objects: add_vectors = document_objects[0].embedding is not None # If these are not labels, we need to find the correct value for `doc_type` metadata field @@ -563,53 +622,47 @@ def write_documents( type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING else: type_metadata = LABEL - if not add_vectors: - # To store documents in Pinecone, we use dummy embeddings (to be replaced with real embeddings later) - embeddings_to_index = np.zeros((batch_size, self.embedding_dim), dtype="float32") - # Convert embeddings to list objects - embeddings = [embed.tolist() if embed is not None else None for embed in embeddings_to_index] with tqdm( total=len(document_objects), disable=not self.progress_bar, position=0, desc="Writing Documents" ) as progress_bar: - for document_batch in get_batches_from_generator(document_objects, batch_size): - document_batch = list(document_batch) - document_batch_copy = deepcopy(document_batch) - ids = [doc.id for doc in document_batch] + for document_chunk in get_batches_from_generator(document_objects, chunk_size): + document_chunk = list(document_chunk) + ids = [doc.id for doc in document_chunk] # If duplicate_documents set to `skip` or `fail`, we need to check for existing documents if duplicate_documents in ["skip", "fail"]: existing_documents = self.get_documents_by_id( - ids=ids, index=index, namespace=namespace, include_type_metadata=True + ids=ids, index=index, namespace=namespace, include_type_metadata=True, batch_size=chunk_size ) - # First check for documents in current batch that exist in the index + # First check for documents in current chunk that exist in the index if existing_documents: if duplicate_documents == "skip": # If we should skip existing documents, we drop the ids that already exist skip_ids = [doc.id for doc in existing_documents] - # We need to drop the affected document objects from the batch - document_batch = [doc for doc in document_batch if doc.id not in skip_ids] + # We need to drop the affected document objects from the chunk + document_chunk = [doc for doc in document_chunk if doc.id not in skip_ids] # Now rebuild the ID list - ids = [doc.id for doc in document_batch] + ids = [doc.id for doc in document_chunk] progress_bar.update(len(skip_ids)) elif duplicate_documents == "fail": # Otherwise, we raise an error raise DuplicateDocumentError( f"Document ID {existing_documents[0].id} already exists in index {index}" ) - # Now check for duplicate documents within the batch itself + # Now check for duplicate documents within the chunk itself if len(ids) != len(set(ids)): if duplicate_documents == "skip": # We just 
keep the first instance of each duplicate document ids = [] - temp_document_batch = [] - for doc in document_batch: + temp_document_chunk = [] + for doc in document_chunk: if doc.id not in ids: ids.append(doc.id) - temp_document_batch.append(doc) - document_batch = temp_document_batch + temp_document_chunk.append(doc) + document_chunk = temp_document_chunk elif duplicate_documents == "fail": # Otherwise, we raise an error - raise DuplicateDocumentError(f"Duplicate document IDs found in batch: {ids}") + raise DuplicateDocumentError(f"Duplicate document IDs found in chunk: {ids}") metadata = [ self._meta_for_pinecone( { @@ -619,22 +672,26 @@ def write_documents( **doc.meta, } ) - for doc in document_batch_copy + for doc in document_chunk ] if add_vectors: - embeddings = [doc.embedding for doc in document_batch_copy] + embeddings = [doc.embedding for doc in document_chunk] embeddings_to_index = np.array(embeddings, dtype="float32") if self.similarity == "cosine": # Normalize embeddings inplace self.normalize_embedding(embeddings_to_index) # Convert embeddings to list objects embeddings = [embed.tolist() if embed is not None else None for embed in embeddings_to_index] - data_to_write_to_pinecone = zip(ids, embeddings, metadata) - # Metadata fields and embeddings are stored in Pinecone - self.pinecone_indexes[index].upsert(vectors=data_to_write_to_pinecone, namespace=namespace) + else: + # Use dummy embeddings for all documents + embeddings = [self.dummy_vector] * len(document_chunk) + + data_to_write_to_pinecone = list(zip(ids, embeddings, metadata)) + # Store chunk by chunk (for regular upsert) or chunk by chunk (for async upsert) in vector store + self._upsert_vectors(index, data_to_write_to_pinecone, namespace, use_async, batch_size) # type: ignore # Add IDs to ID list self._add_local_ids(index, ids) - progress_bar.update(batch_size) + progress_bar.update(chunk_size) progress_bar.close() def _create_document_field_map(self) -> Dict: @@ -648,6 +705,8 @@ def update_embeddings( filters: Optional[FilterType] = None, batch_size: int = DEFAULT_BATCH_SIZE, namespace: Optional[str] = None, + use_async: bool = False, + document_chunk_size: int = DEFAULT_DOCUMENT_CHUNK_SIZE, ): """ Updates the embeddings in the document store using the encoding model specified in the retriever. @@ -688,6 +747,9 @@ def update_embeddings( :param batch_size: Number of documents to process at a time. When working with large number of documents, batching can help reduce memory footprint. :param namespace: Optional namespace to retrieve document from. If not specified, None is default. + :param use_async: If set to True, Pinecone index will update embeddings in parallel. + :param document_chunk_size: Number of documents to process at a time. If use_async is set to True, + along with batch_size will speed up updating the embeddings by doing it in parallel. """ index = self._index(index) if index not in self.pinecone_indexes: @@ -695,6 +757,21 @@ def update_embeddings( f"Couldn't find a the index '{index}' in Pinecone. Try to init the " f"PineconeDocumentStore() again ..." ) + + pool_threads = self.pinecone_indexes[index].pool_threads + if use_async and pool_threads == 1: + logger.warning( + "Embeddings will be upserted synchronosly, because the number of threads for Pinecone index is %s. 
" + "To enable upsert in parallel, initialize PineconeDocumentStore() again setting parameter `pool_threads`.", + pool_threads, + ) + elif not use_async and pool_threads > 1: + logger.warning( + "Parameter `use_async` set to `False` will be ignored and embeddings will be upserted asynchronously, " + "because the number of threads for Pinecone index is set to %s.", + pool_threads, + ) + document_count = self.get_document_count( index=index, filters=filters, @@ -723,20 +800,22 @@ def update_embeddings( include_type_metadata=True, ) + chunk_size = document_chunk_size if use_async else batch_size with tqdm( total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding" ) as progress_bar: - for _ in range(0, document_count, batch_size): - document_batch = list(islice(documents, batch_size)) - embeddings = retriever.embed_documents(document_batch) + for _ in range(0, document_count, chunk_size): + document_chunk = list(islice(documents, chunk_size)) + document_chunk_size = len(document_chunk) + embeddings = retriever.embed_documents(document_chunk) if embeddings.size == 0: - # Skip batch if there are no embeddings. Otherwise, incorrect embedding shape will be inferred and + # Skip chunk if there are no embeddings. Otherwise, incorrect embedding shape will be inferred and # Pinecone APi will return a "No vectors provided" Bad Request Error progress_bar.set_description_str("Documents Processed") progress_bar.update(batch_size) continue self._validate_embeddings_shape( - embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + embeddings=embeddings, num_documents=document_chunk_size, embedding_dim=self.embedding_dim ) if self.similarity == "cosine": @@ -744,7 +823,7 @@ def update_embeddings( metadata = [] ids = [] - for doc in document_batch: + for doc in document_chunk: metadata.append( self._meta_for_pinecone( { @@ -758,13 +837,12 @@ def update_embeddings( ) ids.append(doc.id) # Update existing vectors in pinecone index - self.pinecone_indexes[index].upsert( - vectors=zip(ids, embeddings.tolist(), metadata), namespace=namespace - ) + data = list(zip(ids, embeddings.tolist(), metadata)) + self._upsert_vectors(index, data, namespace, use_async, batch_size) # type: ignore # Add these vector IDs to local store self._add_local_ids(index, ids) progress_bar.set_description_str("Documents Processed") - progress_bar.update(batch_size) + progress_bar.update(document_chunk_size) def get_all_documents( self, @@ -1020,7 +1098,7 @@ def _move_documents_by_id_namespace( embedding_matrix = [result["vectors"][_id]["values"] for _id in vector_id_matrix] data_to_write_to_pinecone = list(zip(vector_id_matrix, embedding_matrix, meta_matrix)) # Store metadata nd embeddings in new target_namespace - self.pinecone_indexes[index].upsert(vectors=data_to_write_to_pinecone, namespace=target_namespace) + self._upsert_vectors(index, data_to_write_to_pinecone, target_namespace, use_async=False) # type: ignore # Delete vectors from source_namespace self.delete_documents(index=index, ids=id_batch, namespace=source_namespace, drop_ids=False) progress_bar.set_description_str("Documents Moved") @@ -1145,7 +1223,8 @@ def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[st if doc.embedding is not None: meta = {"content": doc.content, "content_type": doc.content_type, **meta} - self.pinecone_indexes[index].upsert(vectors=[(id, doc.embedding.tolist(), meta)], namespace=self.namespace) + data = [(id, doc.embedding.tolist(), meta)] + 
self._upsert_vectors(index, data, self.namespace, use_async=False) # type: ignore def delete_documents( self, @@ -1501,7 +1580,7 @@ def _get_ids( # Retrieve embeddings from Pinecone try: res = self.pinecone_indexes[index].query( - self.dummy_query, + self.dummy_vector, top_k=batch_size, include_values=False, include_metadata=False, @@ -1749,7 +1828,7 @@ def delete_labels( self._index_connection_exists(index) i = 0 - dummy_query = np.asarray(self.dummy_query) + dummy_query = np.asarray(self.dummy_vector) type_metadata = LABEL diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index eac2a6ec7a..8f5a149e31 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -286,6 +286,7 @@ def get_metadata_values_by_key( filters: Optional[FilterType] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + batch_size: int = 10, ) -> List[dict]: """ Get values associated with a metadata key. The output is in the format: @@ -323,10 +324,16 @@ def get_metadata_values_by_key( self.index will be used. :param headers: Custom HTTP headers to pass to the client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='}) Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information. + :param batch_size: Maximum number of results for each request. + Limited to 10 values by default. You can increase this limit to decrease retrieval time. + To reduce the pressure on the cluster, you shouldn't set this higher than 1,000. """ + index = index or self.index body: dict = { "size": 0, - "aggs": {"metadata_agg": {"composite": {"sources": [{key: {"terms": {"field": key}}}]}}}, + "aggs": { + "metadata_agg": {"composite": {"sources": [{key: {"terms": {"field": key}}}], "size": batch_size}} + }, } if query: body["query"] = { diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index 5484ca16c6..5143c330f7 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -104,7 +104,7 @@ def __init__( :param host: Weaviate server connection URL for storing and processing documents and vectors. For more details, see [Weaviate installation](https://weaviate.io/developers/weaviate/current/getting-started/installation.html). :param port: The port of the Weaviate instance. - :param timeout_config: The Weaviate timeout config as a tuple of (retries, time out seconds). + :param timeout_config: The Weaviate timeout config as a tuple of (connect timeout, read timeout). :param username: The Weaviate username (standard authentication using http_auth). :param password: Weaviate password (standard authentication using http_auth). :param scope: The scope of the credentials when using the OIDC Resource Owner Password or Client Credentials authentication flow. 
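The corrected `timeout_config` description translates to something like the following sketch (host and port are the usual local defaults, not values from this patch):

```python
from haystack.document_stores import WeaviateDocumentStore

# The tuple is passed to the Weaviate client as (connect timeout, read timeout) in seconds,
# not as (retries, timeout).
document_store = WeaviateDocumentStore(host="http://localhost", port=8080, timeout_config=(5, 15))
```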
@@ -684,8 +684,10 @@ def write_documents( property_value = _doc[property] if property in json_fields: property_value = doc.meta[property] - self._update_schema(property, property_value, index) - current_properties.append(property) + # if the property_value is an empty list, we can't infer the type + if not isinstance(property_value, list) or len(property_value) > 0: + self._update_schema(property, property_value, index) + current_properties.append(property) # update the date fields as there might be new ones date_fields = self._get_date_properties(index) diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py index e1f4028eba..abc15a5547 100644 --- a/haystack/modeling/model/language_model.py +++ b/haystack/modeling/model/language_model.py @@ -18,26 +18,24 @@ Thanks for the great work! """ -from typing import Type, Optional, Dict, Any, Union, List - -import re import json import logging import os +import re from abc import ABC, abstractmethod from pathlib import Path +from typing import Any, Dict, List, Optional, Type, Union + import numpy as np import torch -from torch import nn import transformers -from transformers import PretrainedConfig, PreTrainedModel -from transformers import AutoModel, AutoConfig +from torch import nn +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel from transformers.modeling_utils import SequenceSummary from haystack.errors import ModelingError from haystack.modeling.utils import silence_transformers_logs - logger = logging.getLogger(__name__) @@ -213,8 +211,7 @@ def _pool_tokens( ): token_vecs = sequence_output.cpu().numpy() # we only take the aggregated value of non-padding tokens - padding_mask = padding_mask.cpu().numpy() - ignore_mask_2d = padding_mask == 0 + ignore_mask_2d = padding_mask.cpu().numpy() == 0 # sometimes we want to exclude the CLS token as well from our aggregation operation if ignore_first_token: ignore_mask_2d[:, 0] = True @@ -225,7 +222,7 @@ def _pool_tokens( if strategy == "reduce_mean": pooled_vecs = np.ma.array(data=token_vecs, mask=ignore_mask_3d).mean(axis=1).data - return pooled_vecs + return pooled_vecs # pylint: disable=possibly-used-before-assignment class HFLanguageModel(LanguageModel): diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py index 45e2d72690..74136ddf23 100644 --- a/haystack/modeling/model/prediction_head.py +++ b/haystack/modeling/model/prediction_head.py @@ -1,3 +1,5 @@ +# pylint: skip-file + import json import logging import os @@ -336,6 +338,8 @@ def load( # type: ignore head = cls(layer_dims=[full_qa_model.config.hidden_size, 2], task_name="question_answering") # transfer weights for head from full model head.feed_forward.feed_forward[0].load_state_dict(full_qa_model.qa_outputs.state_dict()) + # Set the last feed_forward layer to the correct torch dtype + head.feed_forward.feed_forward[0].to(full_qa_model.qa_outputs.weight.dtype) del full_qa_model return head @@ -498,15 +502,14 @@ def logits_to_preds( # sorted_candidates.shape : (batch_size, max_seq_len^2, 2) start_indices = torch.div(flat_sorted_indices, max_seq_len, rounding_mode="trunc") end_indices = flat_sorted_indices % max_seq_len - sorted_candidates = torch.cat((start_indices, end_indices), dim=2) # Get the n_best candidate answers for each sample - sorted_candidates = sorted_candidates.cpu().numpy() - start_end_matrix = start_end_matrix.cpu().numpy() + sorted_candidates = torch.cat((start_indices, end_indices), dim=2).cpu().numpy() 
+ start_end_matrix_array = start_end_matrix.cpu().numpy() for sample_idx in range(batch_size): sample_top_n = self.get_top_candidates( sorted_candidates[sample_idx], - start_end_matrix[sample_idx], + start_end_matrix_array[sample_idx], sample_idx, start_matrix=start_matrix[sample_idx], end_matrix=end_matrix[sample_idx], @@ -971,7 +974,9 @@ def cosine_scores(cls, query_vectors: torch.Tensor, passage_vectors: torch.Tenso passages_per_batch = passage_vectors.shape[0] for query_vector in query_vectors: query_vector_repeated = query_vector.repeat(passages_per_batch, 1) - current_cosine_similarities = nn.functional.cosine_similarity(query_vector_repeated, passage_vectors, dim=1) + current_cosine_similarities = nn.functional.cosine_similarity( # pylint: disable=not-callable + query_vector_repeated, passage_vectors, dim=1 + ) cosine_similarities.append(current_cosine_similarities) return torch.stack(cosine_similarities) diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index 9ac1f3268b..9ea7c18187 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -1,6 +1,5 @@ from haystack.nodes.base import BaseComponent -from haystack.nodes.answer_generator import BaseGenerator, OpenAIAnswerGenerator from haystack.nodes.document_classifier import BaseDocumentClassifier, TransformersDocumentClassifier from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa from haystack.nodes.file_classifier import FileTypeClassifier diff --git a/haystack/nodes/answer_generator/__init__.py b/haystack/nodes/answer_generator/__init__.py deleted file mode 100644 index d4c7eeb558..0000000000 --- a/haystack/nodes/answer_generator/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from haystack.nodes.answer_generator.base import BaseGenerator -from haystack.nodes.answer_generator.openai import OpenAIAnswerGenerator diff --git a/haystack/nodes/answer_generator/base.py b/haystack/nodes/answer_generator/base.py deleted file mode 100644 index 3aedff4ac8..0000000000 --- a/haystack/nodes/answer_generator/base.py +++ /dev/null @@ -1,226 +0,0 @@ -from abc import abstractmethod -from typing import Any, List, Optional, Dict, Union - -from tqdm import tqdm - -from haystack.errors import HaystackError -from haystack.schema import Answer, Document, MultiLabel -from haystack.nodes.base import BaseComponent - - -class BaseGenerator(BaseComponent): - """ - Abstract class for Generators - """ - - outgoing_edges = 1 - - def __init__(self, progress_bar: bool = True): - super().__init__() - self.progress_bar = progress_bar - - @abstractmethod - def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict: - """ - Abstract method to generate answers. - - :param query: Query string. - :param documents: Related documents (for example, coming from a retriever) the answer should be based on. - :param top_k: Number of returned answers. - :param max_tokens: The maximum number of tokens the generated answer can have. - :return: Generated answers plus additional infos in a dict. - """ - pass - - def run( # type: ignore - self, - query: str, - documents: List[Document], - top_k: Optional[int] = None, - labels: Optional[MultiLabel] = None, - add_isolated_node_eval: bool = False, - max_tokens: Optional[int] = None, - ): # type: ignore - """ - :param query: Query string. - :param documents: List of Documents the answer should be based on. - :param top_k: The maximum number of answers to return. - :param labels: Labels to be used for evaluation. 
- :param add_isolated_node_eval: If True, the answer generator will be evaluated in isolation. - :param max_tokens: The maximum number of tokens the generated answer can have. - """ - if documents: - results = self.predict(query=query, documents=documents, top_k=top_k, max_tokens=max_tokens) - else: - results = {"answers": []} - - # run evaluation with "perfect" labels as node inputs to calculate "upper bound" metrics for just this node - if add_isolated_node_eval and labels is not None: - relevant_documents = list({label.document.id: label.document for label in labels.labels}.values()) - results_label_input = self.predict( - query=query, documents=relevant_documents, top_k=top_k, max_tokens=max_tokens - ) - results["answers_isolated"] = results_label_input["answers"] - - return results, "output_1" - - def run_batch( # type: ignore - self, - queries: List[str], - documents: Union[List[Document], List[List[Document]]], - top_k: Optional[int] = None, - labels: Optional[List[MultiLabel]] = None, - batch_size: Optional[int] = None, - add_isolated_node_eval: bool = False, - max_tokens: Optional[int] = None, - ): - """ - :param queries: List of query strings. - :param documents: List of list of Documents the answer should be based on. - :param top_k: The maximum number of answers to return. - :param labels: Labels to be used for evaluation. - :param add_isolated_node_eval: If True, the answer generator will be evaluated in isolation. - :param max_tokens: The maximum number of tokens the generated answer can have. - """ - results = self.predict_batch( - queries=queries, documents=documents, top_k=top_k, batch_size=batch_size, max_tokens=max_tokens - ) - - # run evaluation with "perfect" labels as node inputs to calculate "upper bound" metrics for just this node - if add_isolated_node_eval and labels is not None: - relevant_documents = [] - for labelx in labels: - # Deduplicate same Documents in a MultiLabel based on their Document ID and filter out empty Documents - relevant_docs_labels = list( - { - label.document.id: label.document - for label in labelx.labels - if not isinstance(label.document.content, str) or label.document.content.strip() != "" - }.values() - ) - relevant_documents.append(relevant_docs_labels) - results_label_input = self.predict_batch(queries=queries, documents=relevant_documents, top_k=top_k) - - results["answers_isolated"] = results_label_input["answers"] - return results, "output_1" - - def _flatten_docs(self, documents: List[Document]): - flat_docs_dict: Dict[str, Any] = {} - for document in documents: - for k, v in document.to_dict().items(): - if k not in flat_docs_dict: - flat_docs_dict[k] = [] - flat_docs_dict[k].append(v) - return flat_docs_dict - - def _create_answers( - self, generated_answers: List[str], documents: List[Document], prompt: Optional[str] = None - ) -> List[Answer]: - flat_docs_dict = self._flatten_docs(documents) - answers: List[Any] = [] - for generated_answer in generated_answers: - answers.append( - Answer( - answer=generated_answer, - document_ids=flat_docs_dict.get("id"), - type="generative", - meta={ - "doc_scores": flat_docs_dict.get("score"), - "content": flat_docs_dict.get("content"), - "titles": [d.get("name", "") for d in flat_docs_dict.get("meta", [])], - "doc_metas": flat_docs_dict.get("meta"), - "prompt": prompt, - }, - ) - ) - return answers - - def predict_batch( - self, - queries: List[str], - documents: Union[List[Document], List[List[Document]]], - top_k: Optional[int] = None, - batch_size: Optional[int] = None, - max_tokens: 
Optional[int] = None, - ): - """ - Generate the answer to the input queries. The generation will be conditioned on the supplied documents. - These documents can for example be retrieved via the Retriever. - - - If you provide a list containing a single query... - - - ... and a single list of Documents, the query will be applied to each Document individually. - - ... and a list of lists of Documents, the query will be applied to each list of Documents and the Answers - will be aggregated per Document list. - - - If you provide a list of multiple queries... - - - ... and a single list of Documents, each query will be applied to each Document individually. - - ... and a list of lists of Documents, each query will be applied to its corresponding list of Documents - and the Answers will be aggregated per query-Document pair. - - :param queries: List of queries. - :param documents: Related documents (for example, coming from a retriever) the answer should be based on. - Can be a single list of Documents or a list of lists of Documents. - :param top_k: Number of returned answers per query. - :param batch_size: Not applicable. - :param max_tokens: The maximum number of tokens the generated answer can have. - :return: Generated answers plus additional infos in a dict like this: - - ```python - {'queries': 'who got the first nobel prize in physics', - 'answers': - [{'query': 'who got the first nobel prize in physics', - 'answer': ' albert einstein', - 'meta': { 'doc_ids': [...], - 'doc_scores': [80.42758 ...], - 'doc_probabilities': [40.71379089355469, ... - 'content': ['Albert Einstein was a ...] - 'titles': ['"Albert Einstein"', ...] - }}]} - ``` - """ - # TODO: This method currently just calls the predict method multiple times, so there is room for improvement. 
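As a plain-Python sketch of the per-query grouping the docstring above describes (one shared Document list, several queries); the answer strings are placeholders.

```python
queries = ["who wrote Faust?", "who painted the Mona Lisa?"]
documents = ["doc-1", "doc-2", "doc-3"]  # stand-ins for Haystack Documents

# predict() runs once per (query, document) pair, producing a flat list of answers ...
flat_answers = [f"{query} -> {doc}" for query in queries for doc in documents]

# ... which is then regrouped so each query receives its own list of answers.
answers_per_query = len(flat_answers) // len(queries)
grouped_answers = [
    flat_answers[i : i + answers_per_query] for i in range(0, len(flat_answers), answers_per_query)
]
# grouped_answers[0] holds the three answers for the first query, grouped_answers[1] for the second.
```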
- - results: Dict = {"queries": queries, "answers": []} - - single_doc_list = False - # Docs case 1: single list of Documents -> apply each query to all Documents - if len(documents) > 0 and isinstance(documents[0], Document): - single_doc_list = True - pb = tqdm(total=len(queries) * len(documents), disable=not self.progress_bar, desc="Generating answers") - for query in queries: - for doc in documents: - if not isinstance(doc, Document): - raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.") - preds = self.predict(query=query, documents=[doc], top_k=top_k, max_tokens=max_tokens) - results["answers"].append(preds["answers"]) - pb.update(1) - pb.close() - - # Docs case 2: list of lists of Documents -> apply each query to corresponding list of Documents, if queries - # contains only one query, apply it to each list of Documents - elif len(documents) > 0 and isinstance(documents[0], list): - if len(queries) == 1: - queries = queries * len(documents) - if len(queries) != len(documents): - raise HaystackError("Number of queries must be equal to number of provided Document lists.") - pb = tqdm(total=min(len(queries), len(documents)), disable=not self.progress_bar, desc="Generating answers") - for query, cur_docs in zip(queries, documents): - if not isinstance(cur_docs, list): - raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.") - preds = self.predict(query=query, documents=cur_docs, top_k=top_k, max_tokens=max_tokens) - results["answers"].append(preds["answers"]) - pb.update(1) - pb.close() - - # Group answers by question in case of multiple queries and single doc list - if single_doc_list and len(queries) > 1: - answers_per_query = int(len(results["answers"]) / len(queries)) - answers = [] - for i in range(0, len(results["answers"]), answers_per_query): - answer_group = results["answers"][i : i + answers_per_query] - answers.append(answer_group) - results["answers"] = answers - - return results diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py deleted file mode 100644 index 68edd16b03..0000000000 --- a/haystack/nodes/answer_generator/openai.py +++ /dev/null @@ -1,338 +0,0 @@ -import logging -import os -from typing import List, Optional, Tuple, Union -import warnings - -from haystack import Document -from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC -from haystack.nodes.answer_generator import BaseGenerator -from haystack.nodes.prompt import PromptTemplate -from haystack.utils.openai_utils import ( - load_openai_tokenizer, - openai_request, - _openai_text_completion_tokenization_details, - _check_openai_finish_reason, - check_openai_policy_violation, -) - -logger = logging.getLogger(__name__) - -OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30)) - - -class OpenAIAnswerGenerator(BaseGenerator): - """ - This component is now deprecated and will be removed in future versions. - Use `PromptNode` instead of `OpenAIAnswerGenerator`, - as explained in https://haystack.deepset.ai/tutorials/22_pipeline_with_promptnode. - - Uses the GPT-3 models from the OpenAI API to generate Answers based on the Documents it receives. - The Documents can come from a Retriever or you can supply them manually. - - To use this Node, you need an API key from an active OpenAI account. You can sign-up for an account - on the [OpenAI API website](https://openai.com/api/). 
- """ - - def __init__( - self, - api_key: str, - azure_base_url: Optional[str] = None, - azure_deployment_name: Optional[str] = None, - model: str = "text-davinci-003", - max_tokens: int = 50, - api_version: str = "2022-12-01", - top_k: int = 5, - temperature: float = 0.2, - presence_penalty: float = 0.1, - frequency_penalty: float = 0.1, - examples_context: Optional[str] = None, - examples: Optional[List[List[str]]] = None, - stop_words: Optional[List[str]] = None, - progress_bar: bool = True, - prompt_template: Optional[PromptTemplate] = None, - context_join_str: str = " ", - moderate_content: bool = False, - api_base: str = "https://api.openai.com/v1", - openai_organization: Optional[str] = None, - ): - """ - :param api_key: Your API key from OpenAI. It is required for this node to work. - :param azure_base_url: The base URL for the Azure OpenAI API. If not supplied, Azure OpenAI API will not be used. - This parameter is an OpenAI Azure endpoint, usually in the form `https://.openai.azure.com`. - :param azure_deployment_name: The name of the Azure OpenAI API deployment. If not supplied, Azure OpenAI API will not be used. - :param model: ID of the engine to use for generating the answer. You can select one of `"text-ada-001"`, - `"text-babbage-001"`, `"text-curie-001"`, or `"text-davinci-003"` - (from worst to best and from cheapest to most expensive). For more information about the models, - refer to the [OpenAI Documentation](https://platform.openai.com/docs/models/gpt-3). - :param max_tokens: The maximum number of tokens reserved for the generated Answer. - A higher number allows for longer answers without exceeding the max prompt length of the OpenAI model. - A lower number allows longer prompts with more documents passed as context, but the generated answer might be cut after max_tokens. - :param api_version: The version of the Azure OpenAI API to use. The default is `2022-12-01` version. - :param top_k: Number of generated Answers. - :param temperature: What sampling temperature to use. Higher values mean the model will take more risks and - value 0 (argmax sampling) works better for scenarios with a well-defined Answer. - :param presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they have already appeared - in the text. This increases the model's likelihood to talk about new topics. For more information about frequency and presence penalties, see - [parameter details in OpenAI](https://platform.openai.com/docs/api-reference/parameter-details). - :param frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - frequency in the text so far, decreasing the model's likelihood to repeat the same line - verbatim. - [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - :param examples_context: A text snippet containing the contextual information used to generate the Answers for - the examples you provide. - If not supplied, the default from OpenAI API docs is used: - `"In 2017, U.S. life expectancy was 78.6 years."` - :param examples: List of (question, answer) pairs that helps steer the model towards the tone and answer - format you'd like. We recommend adding 2 to 3 examples. 
- If not supplied, the default from OpenAI API docs is used: - `[["Q: What is human life expectancy in the United States?", "A: 78 years."]]` - :param stop_words: Up to four sequences where the API stops generating further tokens. The returned text does not contain the stop sequence. - If you don't provide any stop words, the default value from OpenAI API docs is used: `["\\n", "<|endoftext|>"]`. - :param prompt_template: A PromptTemplate that tells the model how to generate answers given a - `context` and `query` supplied at runtime. The `context` is automatically constructed at runtime from a - list of provided documents. Use `example_context` and a list of `examples` to provide the model with examples to steer it towards the tone and answer format you would like. - If not supplied, the default prompt template is: - ```python - PromptTemplate( - "Please answer the question according to the above context." - "\\n===\\nContext: {examples_context}\\n===\\n{examples}\\n\\n" - "===\\nContext: {context}\\n===\\n{query}", - ) - ``` - To learn how variables, such as '{context}', are substituted in the prompt text, see - [PromptTemplate](https://docs.haystack.deepset.ai/docs/prompt_node#template-structure). - :param context_join_str: The separation string used to join the input documents to create the context - used by the PromptTemplate. - :param moderate_content: Whether to filter input and generated answers for potentially sensitive content - using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation). If the input or - answers are flagged, an empty list is returned in place of the answers. - :param api_base: The base URL for the OpenAI API, defaults to `"https://api.openai.com/v1"`. - :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see see OpenAI - [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). - """ - - warnings.warn( - "`OpenAIAnswerGenerator component is deprecated and will be removed in future versions. Use `PromptNode` " - "instead of `OpenAIAnswerGenerator`.", - category=DeprecationWarning, - ) - - super().__init__(progress_bar=progress_bar) - if (examples is None and examples_context is not None) or (examples is not None and examples_context is None): - logger.warning( - "If providing examples or examples_context, we recommend providing both of them " - "so the examples correctly refer to the examples_context." - ) - if examples_context is None: - examples_context = "In 2017, U.S. life expectancy was 78.6 years." - if examples is None: - examples = [["Q: What is human life expectancy in the United States?", "A: 78 years."]] - if stop_words is None: - stop_words = ["\n", "<|endoftext|>"] - if prompt_template is None: - prompt_template = PromptTemplate( - "Please answer the question according to the above context." - "\n===\nContext: {examples_context}\n===\n{examples}\n\n" - "===\nContext: {context}\n===\n{query}" - ) - else: - # Check for required prompts - required_params = ["context", "query"] - if not all(p in prompt_template.prompt_params for p in required_params): - raise ValueError( - "The OpenAIAnswerGenerator requires a PromptTemplate that has `context` and " - "`query` in its `prompt_params`. Supply a different `prompt_template` or " - "use the default one." 
- ) - - # Check for unsupported prompt parameters - optional_params = ["examples_context", "examples"] - unknown_params = [] - for p in prompt_template.prompt_params: - if p not in set(required_params + optional_params): - unknown_params.append(p) - if len(unknown_params) > 1: - raise ValueError( - f"The provided PromptTemplate has the prompt parameters, {unknown_params}, that are not supported " - f"by the OpenAIAnswerGenerator. The only prompt parameters that are supported are " - f"`examples_context`, `examples`, `context`, and `query`." - ) - - self.api_key = api_key - self.azure_base_url = azure_base_url - self.azure_deployment_name = azure_deployment_name - self.api_version = api_version - self.api_base = api_base - self.model = model - self.max_tokens = max_tokens - self.top_k = top_k - self.temperature = temperature - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.examples_context = examples_context - self.examples = examples - self.stop_words = stop_words - self.prompt_template = prompt_template - self.context_join_str = context_join_str - self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None - self.moderate_content = moderate_content - self.openai_organization = openai_organization - - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model) - - self.MAX_TOKENS_LIMIT = max_tokens_limit - self._tokenizer = load_openai_tokenizer(tokenizer_name=tokenizer_name) - - def predict( - self, - query: str, - documents: List[Document], - top_k: Optional[int] = None, - max_tokens: Optional[int] = None, - timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT, - ): - """ - Use the loaded QA model to generate Answers for a query based on the Documents it receives. - - Returns dictionaries containing Answers. - Note that OpenAI doesn't return scores for those Answers. - - Example: - ```python - { - 'query': 'Who is the father of Arya Stark?', - 'answers':[Answer( - 'answer': 'Eddard,', - 'score': None, - ),... - ] - } - ``` - - :param query: The query you want to provide. It's a string. - :param documents: List of Documents in which to search for the Answer. - :param top_k: The maximum number of Answers to return. - :param max_tokens: The maximum number of tokens the generated Answer can have. - :param timeout: How many seconds to wait for the server to send data before giving up, - as a float, or a :ref:`(connect timeout, read timeout) ` tuple. - Defaults to 10 seconds. - :return: Dictionary containing query and Answers. 
- """ - if top_k is None: - top_k = self.top_k - - # convert input to OpenAI format - prompt, input_docs = self._build_prompt_within_max_length(query=query, documents=documents) - logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt) - - payload = { - "model": self.model, - "prompt": prompt, - "max_tokens": max_tokens or self.max_tokens, - "stop": self.stop_words, - "n": top_k, - "temperature": self.temperature, - "presence_penalty": self.presence_penalty, - "frequency_penalty": self.frequency_penalty, - } - if self.using_azure: - url = f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}" - else: - url = f"{self.api_base}/completions" - - headers = {"Content-Type": "application/json"} - if self.using_azure: - headers["api-key"] = self.api_key - else: - headers["Authorization"] = f"Bearer {self.api_key}" - if self.openai_organization: - headers["OpenAI-Organization"] = self.openai_organization - - if self.moderate_content and check_openai_policy_violation(input=prompt, headers=headers): - logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt) - return {"query": query, "answers": []} - - logger.debug("Prompt being sent to OpenAI API with prompt %s.", prompt) - res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout) - _check_openai_finish_reason(result=res, payload=payload) - generated_answers = [ans["text"] for ans in res["choices"]] - if self.moderate_content and check_openai_policy_violation(input=generated_answers, headers=headers): - logger.info( - "Generated answers '%s' will not be returned due to potential policy violation.", generated_answers - ) - return {"query": query, "answers": []} - answers = self._create_answers(generated_answers, input_docs, prompt=prompt) - result = {"query": query, "answers": answers} - return result - - @staticmethod - def _create_context(documents: List[Document], join_str: str = " ") -> str: - """Join the documents to create a single context to be used in the PromptTemplate.""" - doc_contents = [doc.content for doc in documents] - # We reverse the docs to put the most relevant documents at the bottom of the context - context = join_str.join(reversed(doc_contents)) - return context - - def _fill_prompt(self, query: str, documents: List[Document]) -> str: - """Fills in the `prompt_template` with its `prompt_params` and returns the full prompt.""" - example_prompts = "\n---\n".join([f"{query}\n{answer}" for query, answer in self.examples]) - qa_prompt = f"Q: {query}\nA:" - - kwargs = {"context": self._create_context(documents, join_str=self.context_join_str), "query": qa_prompt} - if ( - "examples_context" in self.prompt_template.prompt_params - and "examples" in self.prompt_template.prompt_params - ): - kwargs["examples_context"] = self.examples_context - kwargs["examples"] = example_prompts - full_prompt = next(self.prompt_template.fill(**kwargs)) - return full_prompt - - def _build_prompt_within_max_length(self, query: str, documents: List[Document]) -> Tuple[str, List[Document]]: - """ - Builds the prompt for the GPT-3 model so that it can generate an Answer. If the prompt is too long based on the - MAX_TOKENS_LIMIT of the OpenAI model and `max_tokens` you specify, then documents (used to - construct the context) are thrown away until the prompt length fits within the MAX_TOKENS_LIMIT. 
- """ - full_prompt = self._fill_prompt(query, documents) - n_full_prompt_tokens = len(self._tokenizer.encode(full_prompt)) - - # for length restrictions of prompt see: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens - leftover_token_len = self.MAX_TOKENS_LIMIT - n_full_prompt_tokens - self.max_tokens - - # Trim down the prompt (by removing documents) until it fits the models MAX_TOKENS_LIMIT - input_docs = documents - skipped_docs = 0 - # If leftover_token_len is negative we have gone past the MAX_TOKENS_LIMIT and the prompt must be trimmed - if leftover_token_len < 0: - n_skipped_tokens = 0 - # Reversing the order of documents b/c we want to throw away less relevant docs first - for doc in reversed(documents): - skipped_docs += 1 - n_skipped_tokens += len(self._tokenizer.encode(doc.content)) - - # Only skip enough tokens to fit within the MAX_TOKENS_LIMIT - if n_skipped_tokens >= abs(leftover_token_len): - break - - # Throw away least relevant docs - input_docs = documents[:-skipped_docs] - full_prompt = self._fill_prompt(query, input_docs) - n_full_prompt_tokens = len(self._tokenizer.encode(full_prompt)) - - if len(input_docs) == 0: - logger.warning( - "Skipping all of the provided Documents, as none of them fits the maximum token limit of %s. " - "The generated answers will therefore not be conditioned on any context.", - self.MAX_TOKENS_LIMIT, - ) - elif skipped_docs >= 1: - logger.warning( - "Skipping %s of the provided Documents, as using them would exceed the maximum token limit of %s.", - skipped_docs, - self.MAX_TOKENS_LIMIT, - ) - - logger.debug("Number of tokens in full prompt: %s", n_full_prompt_tokens) - logger.debug("Full prompt: %s", full_prompt) - return full_prompt, input_docs diff --git a/haystack/nodes/connector/crawler.py b/haystack/nodes/connector/crawler.py index 415f3b6cf7..bfb6f6db4a 100644 --- a/haystack/nodes/connector/crawler.py +++ b/haystack/nodes/connector/crawler.py @@ -6,15 +6,16 @@ import sys import time from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from urllib.parse import urlparse +from haystack.lazy_imports import LazyImport from haystack.nodes.base import BaseComponent from haystack.schema import Document -from haystack.lazy_imports import LazyImport with LazyImport("Run 'pip install farm-haystack[crawler]'") as selenium_import: - from selenium import webdriver + from selenium import webdriver as selenium_webdriver + from selenium.webdriver.remote.webdriver import WebDriver from selenium.common.exceptions import StaleElementReferenceException from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.service import Service @@ -53,6 +54,7 @@ def __init__( file_path_meta_field_name: Optional[str] = None, crawler_naming_function: Optional[Callable[[str, str], str]] = None, webdriver_options: Optional[List[str]] = None, + webdriver: Optional["WebDriver"] = None, ): """ Init object with basic params for crawling (can be overwritten later). @@ -96,10 +98,27 @@ def __init__( This option enables remote debug over HTTP. See [Chromium Command Line Switches](https://peter.sh/experiments/chromium-command-line-switches/) for more details on the available options. 
If your crawler fails, raising a `selenium.WebDriverException`, this [Stack Overflow thread](https://stackoverflow.com/questions/50642308/webdriverexception-unknown-error-devtoolsactiveport-file-doesnt-exist-while-t) can be helpful. Contains useful suggestions for webdriver_options. + :param webdriver: A pre-configured Selenium WebDriver. + When webdriver_options is not sufficient, use this parameter to override the whole web driver. This lets you use different engines other than the default Chrome. """ selenium_import.check() super().__init__() + self.urls = urls + self.crawler_depth = crawler_depth + self.filter_urls = filter_urls + self.overwrite_existing_files = overwrite_existing_files + self.id_hash_keys = id_hash_keys + self.extract_hidden_text = extract_hidden_text + self.loading_wait_time = loading_wait_time + self.crawler_naming_function = crawler_naming_function + self.output_dir = output_dir + self.file_path_meta_field_name = file_path_meta_field_name + + if webdriver is not None: + self.driver = webdriver + return + IN_COLAB = "google.colab" in sys.modules IN_AZUREML = os.environ.get("AZUREML_ENVIRONMENT_IMAGE", None) == "True" IN_WINDOWS = sys.platform in ["win32", "cygwin"] @@ -118,18 +137,10 @@ def __init__( options = Options() for option in set(webdriver_options): options.add_argument(option) + if self.output_dir: + options.add_experimental_option("prefs", {"download.default_directory": str(self.output_dir)}) - self.driver = webdriver.Chrome(service=Service(), options=options) - self.urls = urls - self.crawler_depth = crawler_depth - self.filter_urls = filter_urls - self.overwrite_existing_files = overwrite_existing_files - self.id_hash_keys = id_hash_keys - self.extract_hidden_text = extract_hidden_text - self.loading_wait_time = loading_wait_time - self.crawler_naming_function = crawler_naming_function - self.output_dir = output_dir - self.file_path_meta_field_name = file_path_meta_field_name + self.driver = selenium_webdriver.Chrome(service=Service(), options=options) def __del__(self): self.driver.quit() @@ -337,6 +348,7 @@ def _crawl_urls( if loading_wait_time is not None: time.sleep(loading_wait_time) el = self.driver.find_element(by=By.TAG_NAME, value="body") + if extract_hidden_text: text = el.get_attribute("textContent") else: @@ -469,9 +481,11 @@ def _extract_sublinks_from_url( loading_wait_time: Optional[int] = None, ) -> Set[str]: self.driver.get(base_url) + if loading_wait_time is not None: time.sleep(loading_wait_time) a_elements = self.driver.find_elements(by=By.XPATH, value="//a[@href]") + sub_links = set() filter_pattern = re.compile("|".join(filter_urls)) if filter_urls is not None else None diff --git a/haystack/nodes/document_classifier/transformers.py b/haystack/nodes/document_classifier/transformers.py index f5c2a84f6b..c58be4abb4 100644 --- a/haystack/nodes/document_classifier/transformers.py +++ b/haystack/nodes/document_classifier/transformers.py @@ -149,7 +149,7 @@ def __init__( model=model_name_or_path, tokenizer=tokenizer, revision=model_version, - use_auth_token=use_auth_token, + token=use_auth_token, device=resolved_devices[0], ) elif task == "text-classification": @@ -160,7 +160,7 @@ def __init__( device=resolved_devices[0], revision=model_version, top_k=top_k, - use_auth_token=use_auth_token, + token=use_auth_token, ) self.top_k = top_k self.labels = labels diff --git a/haystack/nodes/file_classifier/file_type.py b/haystack/nodes/file_classifier/file_type.py index 3a91a89de3..67a3d8802b 100644 --- 
a/haystack/nodes/file_classifier/file_type.py +++ b/haystack/nodes/file_classifier/file_type.py @@ -26,7 +26,9 @@ class FileTypeClassifier(BaseComponent): outgoing_edges = len(DEFAULT_TYPES) - def __init__(self, supported_types: Optional[List[str]] = None, full_analysis: bool = False): + def __init__( + self, supported_types: Optional[List[str]] = None, full_analysis: bool = False, raise_on_error: bool = True + ): """ Node that sends out files on a different output edge depending on their extension. @@ -35,9 +37,11 @@ def __init__(self, supported_types: Optional[List[str]] = None, full_analysis: b You can't use lists with duplicate elements. :param full_analysis: If True, the whole file is analyzed to determine the file type. If False, only the first 2049 bytes are analyzed. + :param raise_on_error: If True, the node will raise an exception if the file type is not supported. """ self.full_analysis = full_analysis self._default_types = False + self._raise_on_error = raise_on_error if supported_types is None: self._default_types = True supported_types = DEFAULT_TYPES @@ -121,6 +125,16 @@ def run(self, file_paths: Union[Path, List[Path], str, List[str], List[Union[Pat try: index = self.supported_types.index(extension) + 1 except ValueError: + if self._raise_on_error is False: + logger.warning( + "Unsupported files of type '%s' (%s) found. " + "Unsupported file types will be ignored during indexing as `raise_on_error` is set to `False`. " + "The supported types are: %s. ", + extension, + paths[0], + self.supported_types, + ) + return {"file_paths": paths}, "output_dead_end" raise ValueError( f"Files of type '{extension}' ({paths[0]}) are not supported. " f"The supported types are: {self.supported_types}. " diff --git a/haystack/nodes/file_converter/__init__.py b/haystack/nodes/file_converter/__init__.py index 76a5dd1aa3..a37bc4b499 100644 --- a/haystack/nodes/file_converter/__init__.py +++ b/haystack/nodes/file_converter/__init__.py @@ -9,16 +9,6 @@ from haystack.nodes.file_converter.txt import TextConverter from haystack.nodes.file_converter.azure import AzureConverter from haystack.nodes.file_converter.parsr import ParsrConverter - - -try: - with LazyImport() as fitz_import: - # Try to use PyMuPDF, if not available fall back to xpdf - from haystack.nodes.file_converter.pdf import PDFToTextConverter # type: ignore - - fitz_import.check() -except (ModuleNotFoundError, ImportError): - from haystack.nodes.file_converter.pdf_xpdf import PDFToTextConverter # type: ignore # pylint: disable=reimported,ungrouped-imports - +from haystack.nodes.file_converter.pdf_xpdf import PDFToTextConverter from haystack.nodes.file_converter.markdown import MarkdownConverter from haystack.nodes.file_converter.image import ImageToTextConverter diff --git a/haystack/nodes/file_converter/base.py b/haystack/nodes/file_converter/base.py index d5fbb3fc58..934cca4689 100644 --- a/haystack/nodes/file_converter/base.py +++ b/haystack/nodes/file_converter/base.py @@ -158,6 +158,7 @@ def run( # type: ignore valid_languages: Optional[List[str]] = None, encoding: Optional[str] = "UTF-8", id_hash_keys: Optional[List[str]] = None, + raise_on_failure: bool = True, ): """ Extract text from a file. @@ -188,6 +189,7 @@ def run( # type: ignore attributes. If you want to ensure you don't have duplicate documents in your DocumentStore but texts are not unique, you can modify the metadata and pass e.g. `"meta"` to this field (e.g. [`"content"`, `"meta"`]). 
In this case the id will be generated by using the content and the defined metadata. + :param raise_on_failure: If true, raises an exception if the conversion of a single file fails. If False, skips the file without failing. """ if known_ligatures is None: known_ligatures = KNOWN_LIGATURES @@ -199,17 +201,24 @@ def run( # type: ignore meta = [meta] * len(file_paths) documents: list = [] + failed_paths: list = [] for file_path, file_meta in tqdm( zip(file_paths, meta), total=len(file_paths), disable=not self.progress_bar, desc="Converting files" ): - documents += self.convert( - file_path=file_path, - meta=file_meta, - remove_numeric_tables=remove_numeric_tables, - valid_languages=valid_languages, - encoding=encoding, - id_hash_keys=id_hash_keys, - ) + try: + documents += self.convert( + file_path=file_path, + meta=file_meta, + remove_numeric_tables=remove_numeric_tables, + valid_languages=valid_languages, + encoding=encoding, + id_hash_keys=id_hash_keys, + ) + except Exception as e: + if raise_on_failure: + raise e + failed_paths.append(str(file_path)) + continue # Cleanup ligatures for document in documents: @@ -217,6 +226,9 @@ def run( # type: ignore if document.content is not None: document.content = document.content.replace(ligature, letters) + if failed_paths: + logger.warning("Conversion of the following file paths failed: %s", ",".join(failed_paths)) + result = {"documents": documents} return result, "output_1" diff --git a/haystack/nodes/file_converter/docx.py b/haystack/nodes/file_converter/docx.py index ae59f13919..eb3ff07933 100644 --- a/haystack/nodes/file_converter/docx.py +++ b/haystack/nodes/file_converter/docx.py @@ -74,8 +74,8 @@ def convert( if id_hash_keys is None: id_hash_keys = self.id_hash_keys - file = docx.Document(file_path) # Creating word reader object. 
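The two tolerance flags introduced above (`raise_on_error` on `FileTypeClassifier` and `raise_on_failure` on the converters' `run`) let indexing pipelines skip problematic files instead of failing; a sketch, assuming both nodes are importable from `haystack.nodes` and the listed files exist:

```python
from haystack.nodes import FileTypeClassifier, TextConverter

# Unsupported extensions are routed to the "output_dead_end" edge instead of raising a ValueError.
classifier = FileTypeClassifier(raise_on_error=False)
output, edge = classifier.run(file_paths=["diagram.bmp"])  # edge == "output_dead_end"

# Files that fail to convert are skipped and logged instead of aborting the whole batch.
converter = TextConverter()
result, _ = converter.run(file_paths=["notes.txt", "possibly_corrupt.txt"], raise_on_failure=False)
```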
- paragraphs = [para.text for para in file.paragraphs] + file = docx.Document(file_path) # type: ignore + paragraphs = [para.text for para in file.paragraphs] # type: ignore text = "\n".join(paragraphs) document = Document(content=text, meta=meta, id_hash_keys=id_hash_keys) return [document] diff --git a/haystack/nodes/file_converter/image.py b/haystack/nodes/file_converter/image.py index 1e4db2d392..9c253579c1 100644 --- a/haystack/nodes/file_converter/image.py +++ b/haystack/nodes/file_converter/image.py @@ -119,7 +119,7 @@ def convert( file_path = Path(file_path) image = Image.open(file_path) - pages = self._image_to_text(image) + pages = self._image_to_text(image) # type: ignore if remove_numeric_tables is None: remove_numeric_tables = self.remove_numeric_tables if valid_languages is None: diff --git a/haystack/nodes/file_converter/pdf.py b/haystack/nodes/file_converter/pdf.py deleted file mode 100644 index e5348671dc..0000000000 --- a/haystack/nodes/file_converter/pdf.py +++ /dev/null @@ -1,307 +0,0 @@ -import logging -import os -import warnings -from concurrent.futures import ProcessPoolExecutor -from multiprocessing import cpu_count -from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -import fitz -from more_itertools import divide - -from haystack.nodes.file_converter.base import BaseConverter -from haystack.schema import Document - - -logger = logging.getLogger(__name__) - - -class PDFToTextConverter(BaseConverter): - def __init__( - self, - remove_numeric_tables: bool = False, - valid_languages: Optional[List[str]] = None, - id_hash_keys: Optional[List[str]] = None, - encoding: Optional[str] = None, - keep_physical_layout: Optional[bool] = None, - sort_by_position: bool = False, - ocr: Optional[Literal["auto", "full"]] = None, - ocr_language: str = "eng", - multiprocessing: Union[bool, int] = True, - ) -> None: - """ - :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables. - The tabular structures in documents might be noise for the reader model if it - does not have table parsing capability for finding answers. However, tables - may also have long strings that could possible candidate for searching answers. - The rows containing strings are thus retained in this option. - :param valid_languages: validate languages from a list of languages specified in the ISO 639-1 - (https://en.wikipedia.org/wiki/ISO_639-1) format. - This option can be used to add test for encoding errors. If the extracted text is - not one of the valid languages, then it might likely be encoding error resulting - in garbled text. - :param id_hash_keys: Generate the document id from a custom list of strings that refer to the document's - attributes. If you want to ensure you don't have duplicate documents in your DocumentStore but texts are - not unique, you can modify the metadata and pass e.g. `"meta"` to this field (e.g. [`"content"`, `"meta"`]). - In this case the id will be generated by using the content and the defined metadata. - :param encoding: This parameter is being deprecated. - It will be automatically detected by PyMuPDF. - :param keep_physical_layout: This parameter is being deprecated. - :param sort_by_position: Specifies whether to sort the extracted text by positional coordinates or logical reading order. - If set to True, the text is sorted first by vertical position, and then by horizontal position. - If set to False (default), the logical reading order in the PDF is used. 
- :param ocr: Specifies whether to use OCR to extract text from images in the PDF. If set to "auto", OCR is used only to extract text - from images and integrate into the existing text. If set to "full", OCR is used to extract text from the entire PDF. - :param ocr_language: Specifies the language to use for OCR. The default language is English, which language code is `eng`. - For a list of supported languages and the respective codes access https://tesseract-ocr.github.io/tessdoc/Data-Files-in-different-versions.html. - You can combine multiple languages by passing a string with the language codes separated by `+`. For example, to use English and German, pass `eng+deu`. - :param multiprocessing: We use multiprocessing to speed up PyMuPDF conversion, you can disable it by setting it to False. - If set to True (the default value), the total number of cores is used. To specify the number of cores to use, set it to an integer. - """ - super().__init__( - remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages, id_hash_keys=id_hash_keys - ) - - self.sort_by_position = sort_by_position - self.multiprocessing = multiprocessing - self.ocr = ocr - self.ocr_language = ocr_language - - if ocr is not None: - if ocr not in ["auto", "full"]: - raise ValueError("The ocr parameter must be either 'auto' or 'full'.") - self._check_tessdata() - - if encoding: - warnings.warn( - "The encoding parameter is being deprecated. It will be automatically detected by PyMuPDF.", - DeprecationWarning, - ) - - if keep_physical_layout: - warnings.warn("The keep_physical_layout parameter is being deprecated.", DeprecationWarning) - - def convert( - self, - file_path: Path, - meta: Optional[Dict[str, Any]] = None, - remove_numeric_tables: Optional[bool] = None, - valid_languages: Optional[List[str]] = None, - encoding: Optional[str] = None, - id_hash_keys: Optional[List[str]] = None, - start_page: Optional[int] = None, - end_page: Optional[int] = None, - keep_physical_layout: Optional[bool] = None, - sort_by_position: Optional[bool] = None, - ocr: Optional[Literal["auto", "full"]] = None, - ocr_language: Optional[str] = None, - multiprocessing: Optional[Union[bool, int]] = None, - ) -> List[Document]: - """ - Extract text from a PDF file and convert it to a Document. - :param file_path: Path to the .pdf file you want to convert - :param meta: Optional dictionary with metadata that shall be attached to all resulting documents. - Can be any custom keys and values. - :param remove_numeric_tables: This option uses heuristics to remove numeric rows from the tables. - The tabular structures in documents might be noise for the reader model if it - does not have table parsing capability for finding answers. However, tables - may also have long strings that could possible candidate for searching answers. - The rows containing strings are thus retained in this option. - :param valid_languages: validate languages from a list of languages specified in the ISO 639-1 - (https://en.wikipedia.org/wiki/ISO_639-1) format. - This option can be used to add test for encoding errors. If the extracted text is - not one of the valid languages, then it might likely be encoding error resulting - in garbled text. - :param encoding: This parameter is being deprecated. - It will be automatically detected by PyMuPDF. - :param keep_physical_layout: This parameter is being deprecated. 
- :param sort_by_position: Specifies whether to sort the extracted text by positional coordinates or logical reading order. - If set to True, the text is sorted first by vertical position, and then by horizontal position. - If set to False (default), the logical reading order in the PDF is used. - :param id_hash_keys: Generate the document id from a custom list of strings that refer to the document's - attributes. If you want to ensure you don't have duplicate documents in your DocumentStore but texts are - not unique, you can modify the metadata and pass e.g. `"meta"` to this field (e.g. [`"content"`, `"meta"`]). - In this case the id will be generated by using the content and the defined metadata. - :param start_page: The page number where to start the conversion - :param end_page: The page number where to end the conversion. - :param ocr: Specifies whether to use OCR to extract text from images in the PDF. If set to "auto", OCR is used only to extract text - from images and integrate into the existing text. If set to "full", OCR is used to extract text from the entire PDF. - To use this feature you must install Tesseract-OCR. For more information, see https://github.com/tesseract-ocr/tesseract#installing-tesseract. - :param ocr_language: Specifies the language to use for OCR. The default language is English, which language code is `eng`. - For a list of supported languages and the respective codes access https://tesseract-ocr.github.io/tessdoc/Data-Files-in-different-versions.html. - You can combine multiple languages by passing a string with the language codes separated by `+`. For example, to use English and German, pass `eng+deu`. - :param multiprocessing: We use multiprocessing to speed up PyMuPDF conversion, you can disable it by setting it to False. - If set to None (the default value), the value defined in the class initialization is used. - If set to True, the total number of cores is used. To specify the number of cores to use, set it to an integer. - """ - if remove_numeric_tables is None: - remove_numeric_tables = self.remove_numeric_tables - if valid_languages is None: - valid_languages = self.valid_languages - if id_hash_keys is None: - id_hash_keys = self.id_hash_keys - if multiprocessing is None: - multiprocessing = self.multiprocessing - if sort_by_position is None: - sort_by_position = self.sort_by_position - if ocr is None: - ocr = self.ocr - if ocr_language is None: - ocr_language = self.ocr_language - - if encoding: - warnings.warn( - "The encoding parameter is being deprecated. It will be automatically detected by PyMuPDF.", - DeprecationWarning, - ) - - if keep_physical_layout: - warnings.warn("The keep_physical_layout parameter is being deprecated.", DeprecationWarning) - - if ocr is not None: - if ocr not in ["auto", "full"]: - raise ValueError("The ocr parameter must be either 'auto' or 'full'.") - self._check_tessdata() - - pages = self._read_pdf( - file_path, - sort_by_position=sort_by_position, - start_page=start_page, - end_page=end_page, - ocr=ocr, - ocr_language=ocr_language, - multiprocessing=multiprocessing, - ) - - cleaned_pages = [] - for page in pages: - lines = page.splitlines() - cleaned_lines = [] - for line in lines: - words = line.split() - digits = [word for word in words if any(i.isdigit() for i in word)] - - # remove lines having > 40% of words as digits AND not ending with a period(.) 
- if ( - remove_numeric_tables - and words - and len(digits) / len(words) > 0.4 - and not line.strip().endswith(".") - ): - logger.debug("Removing line '%s' from %s", line, file_path) - continue - cleaned_lines.append(line) - - page = "\n".join(cleaned_lines) - cleaned_pages.append(page) - - if valid_languages: - document_text = "".join(cleaned_pages) - if not self.validate_language(document_text, valid_languages): - logger.warning( - "The language for %s is not one of %s. The file may not have " - "been decoded in the correct text format.", - file_path, - valid_languages, - ) - - text = "\f".join(cleaned_pages) - document = Document(content=text, meta=meta, id_hash_keys=id_hash_keys) - return [document] - - def _check_tessdata(self): - if os.getenv("TESSDATA_PREFIX") is None: - raise EnvironmentError( - """ - To enable OCR support via PDFToTextConverter, you need to install Tesseract: - - Windows: choco install tesseract-ocr - - Linux (Ubuntu): sudo apt-get install tesseract-ocr - - Mac: brew install tesseract - After that, you need to set the environment variable TESSDATA_PREFIX to the path - of your Tesseract data directory. Typically this is: - - Windows: C:\\Program Files\\Tesseract-OCR\\tessdata - - Linux (Ubuntu): /usr/share/tesseract-ocr/4.00/tessdata - - Mac (Intel): /usr/local/Cellar/tesseract/5.3.0_1/share/tessdata - - Mac (M1/M2): /opt/homebrew/Cellar/tesseract/5.3.0_1/share/tessdata - """ - ) - - def _get_text_parallel(self, page_mp): - idx, filename, parts, sort_by_position, ocr, ocr_language = page_mp - - doc = fitz.open(filename) - - text = "" - for i in parts[idx]: - page = doc[i] - partial_tp = None - if ocr is not None: - full = ocr == "full" - partial_tp = page.get_textpage_ocr(flags=0, full=full, dpi=300, language=ocr_language) - text += page.get_text("text", textpage=partial_tp, sort=sort_by_position) + "\f" - - return text - - def _read_pdf( - self, - file_path: Path, - ocr_language: str, - sort_by_position: bool = False, - start_page: Optional[int] = None, - end_page: Optional[int] = None, - ocr: Optional[Literal["auto", "full"]] = None, - multiprocessing: Optional[Union[bool, int]] = None, - ) -> List[str]: - """ - Extract pages from the pdf file at file_path. - - :param file_path: path of the pdf file - :param sort_by_position: Specifies whether to sort the extracted text by positional coordinates or logical reading order. - If set to True, the text is sorted first by vertical position, and then by horizontal position. - If set to False (default), the logical reading order in the PDF is used. - :param start_page: The page number where to start the conversion, starting from 1. - :param end_page: The page number where to end the conversion. - :param encoding: This parameter is being deprecated. - It will be automatically detected by PyMuPDF. - :param multiprocessing: We use multiprocessing to speed up PyMuPDF conversion, you can disable it by setting it to False. - If set to None (the default value), the value defined in the class initialization is used. - If set to True, the total number of cores is used. To specify the number of cores to use, set it to an integer. 
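The multiprocessing path described in the docstring above hands each worker a contiguous chunk of pages; a small sketch of that partitioning step using `more_itertools.divide`:

```python
from multiprocessing import cpu_count
from more_itertools import divide

page_list = list(range(0, 12))          # 0-based indices of the pages to convert
cpu = min(cpu_count(), len(page_list))  # never start more workers than there are pages
parts = [list(part) for part in divide(cpu, page_list)]
# Each worker re-opens the PDF, extracts only the pages in its chunk, and the
# per-worker texts are concatenated in page order afterwards.
```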
- """ - if start_page is None: - start_page = 0 - else: - start_page = start_page - 1 - - doc = fitz.open(file_path) - page_count = int(doc.page_count) - - if end_page is None or (end_page is not None and end_page > page_count): - end_page = page_count - - document = "" - - if not multiprocessing: - for i in range(start_page, end_page): - page = doc[i] - partial_tp = None - if ocr is not None: - full = ocr == "full" - partial_tp = page.get_textpage_ocr(flags=0, full=full, dpi=300, language=ocr_language) - document += page.get_text("text", textpage=partial_tp, sort=sort_by_position) + "\f" - else: - cpu = cpu_count() if isinstance(multiprocessing, bool) else multiprocessing - page_list = list(range(start_page, end_page)) - cpu = cpu if len(page_list) > cpu else len(page_list) - parts = divide(cpu, page_list) - pages_mp = [(i, file_path, parts, sort_by_position, ocr, ocr_language) for i in range(cpu)] - - with ProcessPoolExecutor(max_workers=cpu) as pool: - results = pool.map(self._get_text_parallel, pages_mp) - for page in results: - document += page - - document = "\f" * start_page + document # tracking skipped pages for correct page numbering - pages = document.split("\f") - pages = pages[:-1] # the last page in the split is always empty. - - return pages diff --git a/haystack/nodes/other/join_docs.py b/haystack/nodes/other/join_docs.py index 274e90a38d..4c71fdf7c5 100644 --- a/haystack/nodes/other/join_docs.py +++ b/haystack/nodes/other/join_docs.py @@ -1,7 +1,7 @@ import logging from collections import defaultdict from math import inf -from typing import List, Optional +from typing import List, Optional, Dict, Tuple from haystack.nodes.other.join import JoinNode from haystack.schema import Document @@ -58,8 +58,13 @@ def __init__( self.top_k_join = top_k_join self.sort_by_score = sort_by_score - def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore + def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: results = [inp["documents"] for inp in inputs] + + # Check if all results are non-empty + if all(not res for res in results): + return {"documents": [], "labels": inputs[0].get("labels", None)}, "output_1" + document_map = {doc.id: doc for result in results for doc in result} if self.join_mode == "concatenate": @@ -98,7 +103,7 @@ def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): return output, "output_1" - def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore + def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # Join single document lists if isinstance(inputs[0]["documents"][0], Document): return self.run(inputs=inputs, top_k_join=top_k_join) @@ -117,13 +122,13 @@ def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = return output, "output_1" - def _concatenate_results(self, results, document_map): + def _concatenate_results(self, results: List[List[Document]], document_map: Dict) -> Dict[str, float]: """ Concatenates multiple document result lists. Return the documents with the higher score. 
""" list_id = list(document_map.keys()) - scores_map = {} + scores_map: Dict[str, float] = {} for idx in list_id: tmp = [] for result in results: @@ -134,11 +139,11 @@ def _concatenate_results(self, results, document_map): scores_map.update({idx: item_best_score.score}) return scores_map - def _calculate_comb_sum(self, results): + def _calculate_comb_sum(self, results: List[List[Document]]) -> Dict[str, float]: """ Calculates a combination sum by multiplying each score by its weight. """ - scores_map = defaultdict(int) + scores_map: Dict[str, float] = defaultdict(float) weights = self.weights if self.weights else [1 / len(results)] * len(results) for result, weight in zip(results, weights): @@ -147,16 +152,24 @@ def _calculate_comb_sum(self, results): return scores_map - def _calculate_rrf(self, results): + def _calculate_rrf(self, results: List[List[Document]]) -> Dict[str, float]: """ Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper, plus 1 as python lists are 0-based and the paper used 1-based ranking). """ K = 61 - scores_map = defaultdict(int) - for result in results: + scores_map: Dict[str, float] = defaultdict(float) + weights = self.weights if self.weights else [1 / len(results)] * len(results) + + # Calculate weighted reciprocal rank fusion score + for result, weight in zip(results, weights): for rank, doc in enumerate(result): - scores_map[doc.id] += 1 / (K + rank) + scores_map[doc.id] += (weight * len(results)) / (K + rank) + + # Normalize scores. Note: len(results) / K is the maximum possible score, + # achieved by being ranked first in all results with non-zero weight. + for id in scores_map: + scores_map[id] = scores_map[id] / (len(results) / K) return scores_map diff --git a/haystack/nodes/other/route_documents.py b/haystack/nodes/other/route_documents.py index e504d05e29..cd4ddfee46 100644 --- a/haystack/nodes/other/route_documents.py +++ b/haystack/nodes/other/route_documents.py @@ -20,7 +20,7 @@ class RouteDocuments(BaseComponent): def __init__( self, split_by: str = "content_type", - metadata_values: Optional[Union[List[str], List[List[str]]]] = None, + metadata_values: Optional[Union[List[Union[str, bool, int]], List[List[Union[str, bool, int]]]]] = None, return_remaining: bool = False, ): """ diff --git a/haystack/nodes/other/shaper.py b/haystack/nodes/other/shaper.py index b367969dcb..2b6ea0cc7f 100644 --- a/haystack/nodes/other/shaper.py +++ b/haystack/nodes/other/shaper.py @@ -745,7 +745,11 @@ def run( # type: ignore meta: Optional[dict] = None, invocation_context: Optional[Dict[str, Any]] = None, ) -> Tuple[Dict, str]: - invocation_context = invocation_context or {} + if invocation_context is None: + invocation_context = {} + else: + invocation_context = invocation_context.copy() + if query and "query" not in invocation_context.keys(): invocation_context["query"] = query @@ -755,7 +759,7 @@ def run( # type: ignore if labels and "labels" not in invocation_context.keys(): invocation_context["labels"] = labels - if documents != None and "documents" not in invocation_context.keys(): + if documents is not None and "documents" not in invocation_context.keys(): invocation_context["documents"] = documents if meta and "meta" not in invocation_context.keys(): diff --git a/haystack/nodes/preprocessor/preprocessor.py b/haystack/nodes/preprocessor/preprocessor.py index 9dddc22e7f..9d76909bda 100644 --- a/haystack/nodes/preprocessor/preprocessor.py +++ b/haystack/nodes/preprocessor/preprocessor.py @@ -28,6 +28,7 @@ with 
LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import: import nltk + from nltk.tokenize.punkt import PunktTokenizer iso639_to_nltk = { "ru": "russian", @@ -59,7 +60,7 @@ def __init__( clean_header_footer: bool = False, clean_empty_lines: bool = True, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word", + split_by: Optional[Literal["token", "word", "sentence", "passage", "page"]] = "word", split_length: int = 200, split_overlap: int = 0, split_respect_sentence_boundary: bool = True, @@ -79,7 +80,7 @@ def __init__( :param clean_whitespace: Strip whitespaces before or after each line in the text. :param clean_empty_lines: Remove more than two empty lines in the text. :param remove_substrings: Remove specified substrings from the text. If no value is provided an empty list is created by default. - :param split_by: Unit for splitting the document. Can be "word", "sentence", or "passage". Set to None to disable splitting. + :param split_by: Unit for splitting the document. Can be "token", "word", "sentence", "passage", or "page". Set to None to disable splitting. :param split_length: Max. number of the above split unit (e.g. words) that are allowed in one document. For instance, if n -> 10 & split_by -> "sentence", then each output document will have 10 sentences. :param split_overlap: Word overlap between two adjacent documents after a split. @@ -120,10 +121,18 @@ def __init__( nltk.data.find("tokenizers/punkt") except LookupError: try: - nltk.download("punkt") + nltk.download("punkt_tab") except FileExistsError as error: logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error) pass + + if tokenizer_model_folder is not None: + warnings.warn( + "Custom NLTK tokenizers are no longer allowed. " + "The 'tokenizer_model_folder' parameter will be ignored. " + "Please use the built-in nltk tokenizers instead by specifying the `language` parameter." 
+ ) + self.tokenizer_model_folder = None self.clean_whitespace = clean_whitespace self.clean_header_footer = clean_header_footer self.clean_empty_lines = clean_empty_lines @@ -134,7 +143,6 @@ def __init__( self.split_respect_sentence_boundary = split_respect_sentence_boundary self.tokenizer = tokenizer self.language = language - self.tokenizer_model_folder = tokenizer_model_folder self.print_log: Set[str] = set() self.id_hash_keys = id_hash_keys self.progress_bar = progress_bar @@ -148,7 +156,7 @@ def process( clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None, + split_by: Optional[Literal["token", "word", "sentence", "passage", "page"]] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, @@ -230,7 +238,7 @@ def _process_single( clean_header_footer: Optional[bool] = None, clean_empty_lines: Optional[bool] = None, remove_substrings: Optional[List[str]] = None, - split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None, + split_by: Optional[Literal["token", "word", "sentence", "passage", "page"]] = None, split_length: Optional[int] = None, split_overlap: Optional[int] = None, split_respect_sentence_boundary: Optional[bool] = None, @@ -347,7 +355,7 @@ def clean( def split( self, document: Union[dict, Document], - split_by: Optional[Literal["token", "word", "sentence", "passage"]], + split_by: Optional[Literal["token", "word", "sentence", "passage", "page"]], split_length: int, split_overlap: int, split_respect_sentence_boundary: bool, @@ -588,7 +596,9 @@ def _get_overlap_from_slice( return processed_sents, next_slice, word_count_slice - def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]: + def _split_into_units( + self, text: str, split_by: Literal["token", "word", "sentence", "passage", "page"], tokenizer: Any + ) -> Tuple[List[str], str]: if split_by == "passage": elements = text.split("\n\n") split_at = "\n\n" @@ -601,9 +611,12 @@ def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[L elif split_by == "token": elements = self._split_tokens(text, tokenizer) split_at = "" + elif split_by == "page": + elements = text.split("\f") + split_at = "\f" else: raise NotImplementedError( - "PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options." + "PreProcessor only supports 'passage', 'sentence', 'word', 'token' or 'page' split_by options." 
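The new `"page"` split unit above splits on form feeds (`"\f"`), which the PDF converters insert between pages; a usage sketch, assuming the standard Haystack v1 imports:

```python
from haystack import Document
from haystack.nodes import PreProcessor

doc = Document(content="First page text\fSecond page text\fThird page text")
preprocessor = PreProcessor(
    split_by="page",
    split_length=1,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    add_page_number=True,
)
splits = preprocessor.process([doc])  # one output Document per page, with a running page number in its meta
```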
) return elements, split_at @@ -631,8 +644,14 @@ def _concatenate_units( processed_units = current_units[: split_length - split_overlap] cur_start_idx += len((split_at_len * " ").join(processed_units)) + split_at_len if self.add_page_number: - num_page_breaks = sum(processed_unit.count("\f") for processed_unit in processed_units) + if split_at != "\f": + num_page_breaks = sum(processed_unit.count("\f") for processed_unit in processed_units) + else: + num_page_breaks = len(processed_units) cur_page += num_page_breaks + else: + if self.add_page_number and split_at == "\f": + cur_page += 1 return text_splits, splits_pages, splits_start_idxs @@ -911,14 +930,14 @@ def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokeni # Use a default NLTK model elif language_name is not None: - sentence_tokenizer = nltk.data.load(f"tokenizers/punkt/{language_name}.pickle") + sentence_tokenizer = PunktTokenizer(language_name) else: logger.error( "PreProcessor couldn't find the default sentence tokenizer model for %s. " " Using English instead. You may train your own model and use the 'tokenizer_model_folder' parameter.", self.language, ) - sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle") + sentence_tokenizer = PunktTokenizer() # default english model return sentence_tokenizer diff --git a/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py b/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py index 51f88b14c4..b9b8ea7a6c 100644 --- a/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py +++ b/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py @@ -78,23 +78,47 @@ class AnthropicClaudeAdapter(BedrockModelAdapter): Model adapter for the Anthropic's Claude model. """ - def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: - default_params = { - "max_tokens_to_sample": self.max_length, - "stop_sequences": ["\n\nHuman:"], - "temperature": None, - "top_p": None, - "top_k": None, - } - params = self._get_params(inference_kwargs, default_params) + def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None: + self.use_messages_api = model_kwargs.get("use_messages_api", True) + super().__init__(model_kwargs, max_length) - body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + if self.use_messages_api: + default_params: Dict[str, Any] = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": self.max_length, + "system": None, + "stop_sequences": None, + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"messages": [{"role": "user", "content": prompt}], **params} + else: + default_params = { + "max_tokens_to_sample": self.max_length, + "stop_sequences": ["\n\nHuman:"], + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} return body def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + if self.use_messages_api: + return [content["text"] for content in response_body["content"]] + return [response_body["completion"]] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + if self.use_messages_api: + return chunk.get("delta", {}).get("text", "") + return chunk.get("completion", "") @@ -197,6 +221,33 @@ def _extract_token_from_stream(self, chunk: 
Dict[str, Any]) -> str: return chunk.get("generation", "") +class MistralAIAdapter(BedrockModelAdapter): + """ + Model adapter for the Mistral's AI models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]: + default_params = { + "max_tokens": self.max_length, + "stop": None, + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + return [output["text"] for output in response_body["outputs"]] + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + outputs: List[Dict[str, str]] = chunk.get("outputs", []) + output = next(iter(outputs), {}) + return output.get("text", "") + + class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer): """ Invocation layer for Amazon Bedrock models. @@ -207,7 +258,8 @@ class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer): r"ai21.j2.*": AI21LabsJurassic2Adapter, r"cohere.command.*": CohereCommandAdapter, r"anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, + r"meta.llama.*": MetaLlama2ChatAdapter, + r"mistral.mi[sx]tral.*": MistralAIAdapter, # codespell:ignore tral } def __init__( diff --git a/haystack/nodes/prompt/invocation_layer/anthropic_claude.py b/haystack/nodes/prompt/invocation_layer/anthropic_claude.py index edd1c6c28c..539d0f9ded 100644 --- a/haystack/nodes/prompt/invocation_layer/anthropic_claude.py +++ b/haystack/nodes/prompt/invocation_layer/anthropic_claude.py @@ -22,9 +22,10 @@ # Taken from: # https://github.com/anthropics/anthropic-sdk-python/blob/main/anthropic/tokenizer.py#L7 # This is a JSON config to load the tokenizer used for Anthropic Claude. -CLAUDE_TOKENIZER_REMOTE_FILE = ( - "https://raw.githubusercontent.com/anthropics/anthropic-sdk-python/main/src/anthropic/tokenizer.json" -) +# Anthropic removed tokenizer.json from their repo (https://github.com/anthropics/anthropic-sdk-python/pull/726), +# we need to use the commit from the latest version of the SDK that still +# has it, i.e. 0.38.0 and commit hash 14afc93ffd809e60666a267763a57a328184c5e4. +CLAUDE_TOKENIZER_REMOTE_FILE = "https://raw.githubusercontent.com/anthropics/anthropic-sdk-python/14afc93ffd809e60666a267763a57a328184c5e4/src/anthropic/tokenizer.json" class AnthropicClaudeInvocationLayer(PromptModelInvocationLayer): diff --git a/haystack/nodes/prompt/invocation_layer/azure_open_ai.py b/haystack/nodes/prompt/invocation_layer/azure_open_ai.py index d10dc65463..001d6da8ba 100644 --- a/haystack/nodes/prompt/invocation_layer/azure_open_ai.py +++ b/haystack/nodes/prompt/invocation_layer/azure_open_ai.py @@ -19,7 +19,7 @@ def __init__( azure_deployment_name: str, api_key: str, api_version: str = "2022-12-01", - model_name_or_path: str = "text-davinci-003", + model_name_or_path: str = "gpt-3.5-turbo-instruct", max_length: Optional[int] = 100, **kwargs, ): @@ -42,7 +42,7 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool: Ensures Azure OpenAI Invocation Layer is selected when `azure_base_url` and `azure_deployment_name` are provided in addition to a list of supported models. 
""" - valid_model = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any( + valid_model = model_name_or_path in ["ada", "babbage", "davinci", "curie", "gpt-3.5-turbo-instruct"] or any( m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"] ) return valid_model and has_azure_parameters(**kwargs) diff --git a/haystack/nodes/prompt/invocation_layer/handlers.py b/haystack/nodes/prompt/invocation_layer/handlers.py index 073561b3f4..11892cb43e 100644 --- a/haystack/nodes/prompt/invocation_layer/handlers.py +++ b/haystack/nodes/prompt/invocation_layer/handlers.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABC -from typing import Union, Dict +from typing import Optional, Union, Dict from haystack.lazy_imports import LazyImport @@ -61,8 +61,14 @@ class DefaultPromptHandler: are within the model_max_length. """ - def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100): - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + def __init__( + self, + model_name_or_path: str, + model_max_length: int, + max_length: int = 100, + use_auth_token: Optional[Union[str, bool]] = None, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=use_auth_token) self.tokenizer.model_max_length = model_max_length self.model_max_length = model_max_length self.max_length = max_length diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py index 83a1cfc9f2..45fba2157b 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py @@ -27,6 +27,7 @@ ) from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler + from haystack.utils.torch_utils import resolve_torch_dtype class StopWordsCriteria(StoppingCriteria): """ @@ -177,13 +178,14 @@ def _prepare_pipeline_kwargs(self, **kwargs) -> Dict[str, Any]: device_map = kwargs.get("device_map", None) device = kwargs.get("device") if device_map is None else None # prepare torch_dtype for pipeline invocation - torch_dtype = self._extract_torch_dtype(**kwargs) + torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype")) # and the model (prefer model instance over model_name_or_path str identifier) model = kwargs.get("model") or kwargs.get("model_name_or_path") trust_remote_code = kwargs.get("trust_remote_code", False) hub_kwargs = { "revision": kwargs.get("revision", None), - "use_auth_token": kwargs.get("use_auth_token", None), + # use_auth_token is no longer used in HuggingFace pipelines. 
We convert it to token + "token": kwargs.get("use_auth_token", None), "trust_remote_code": trust_remote_code, } model_kwargs = kwargs.get("model_kwargs", {}) diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py index 9fd7989b3f..a060d11377 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py @@ -39,7 +39,14 @@ class HFInferenceEndpointInvocationLayer(PromptModelInvocationLayer): """ - def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[int] = 100, **kwargs): + def __init__( + self, + api_key: str, + model_name_or_path: str, + max_length: Optional[int] = 100, + use_auth_token: Optional[Union[str, bool]] = None, + **kwargs, + ): """ Creates an instance of HFInferenceEndpointInvocationLayer :param model_name_or_path: can be either: @@ -76,6 +83,7 @@ def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[i "repetition_penalty", "return_full_text", "seed", + "stop", "stream", "stream_handler", "temperature", @@ -104,6 +112,7 @@ def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[i model_name_or_path=model_name_or_path, model_max_length=model_max_length, max_length=self.max_length or 100, + use_auth_token=use_auth_token, ) def preprocess_prompt(self, prompt: str): diff --git a/haystack/nodes/prompt/invocation_layer/open_ai.py b/haystack/nodes/prompt/invocation_layer/open_ai.py index 825da26234..ee7f5811f5 100644 --- a/haystack/nodes/prompt/invocation_layer/open_ai.py +++ b/haystack/nodes/prompt/invocation_layer/open_ai.py @@ -33,7 +33,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer): def __init__( self, api_key: str, - model_name_or_path: str = "text-davinci-003", + model_name_or_path: str = "gpt-3.5-turbo-instruct", max_length: Optional[int] = 100, api_base: str = "https://api.openai.com/v1", openai_organization: Optional[str] = None, @@ -95,6 +95,8 @@ def __init__( "stream", "stream_handler", "moderate_content", + "seed", + "response_format", ] if key in kwargs } @@ -150,6 +152,12 @@ def _prepare_invoke(self, *args, **kwargs): "frequency_penalty": kwargs_with_defaults.get("frequency_penalty", 0), "logit_bias": kwargs_with_defaults.get("logit_bias", {}), } + response_format = kwargs_with_defaults.get("response_format", None) + if response_format: + base_payload["response_format"] = response_format + seed = kwargs_with_defaults.get("seed", None) + if seed: + base_payload["seed"] = seed return (prompt, base_payload, kwargs_with_defaults, stream, moderation) diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py index 47ce1ee987..62d28b0d5c 100644 --- a/haystack/nodes/prompt/prompt_model.py +++ b/haystack/nodes/prompt/prompt_model.py @@ -1,5 +1,6 @@ import inspect import logging +import importlib from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload from haystack.nodes.base import BaseComponent @@ -36,11 +37,12 @@ def __init__( model_name_or_path: str = "google/flan-t5-base", max_length: Optional[int] = 100, api_key: Optional[str] = None, + api_base: Optional[str] = None, timeout: Optional[float] = None, use_auth_token: Optional[Union[str, bool]] = None, use_gpu: Optional[bool] = None, devices: Optional[List[Union[str, "torch.device"]]] = None, - invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None, + invocation_layer_class: 
Optional[Union[Type[PromptModelInvocationLayer], str]] = None, model_kwargs: Optional[Dict] = None, ): """ @@ -64,6 +66,7 @@ def __init__( self.model_name_or_path = model_name_or_path self.max_length = max_length self.api_key = api_key + self.api_base = api_base self.timeout = timeout self.use_auth_token = use_auth_token self.use_gpu = use_gpu @@ -73,7 +76,7 @@ def __init__( self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class) def create_invocation_layer( - self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] + self, invocation_layer_class: Optional[Union[Type[PromptModelInvocationLayer], str]] ) -> PromptModelInvocationLayer: kwargs = { "api_key": self.api_key, @@ -82,12 +85,27 @@ def create_invocation_layer( "use_gpu": self.use_gpu, "devices": self.devices, } + if self.api_base is not None: + kwargs["api_base"] = self.api_base + all_kwargs = {**self.model_kwargs, **kwargs} - if invocation_layer_class: - return invocation_layer_class( - model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs - ) + if isinstance(invocation_layer_class, str): + module_name, class_name = invocation_layer_class.rsplit(".", maxsplit=1) + try: + module = importlib.import_module(module_name) + except ImportError as e: + msg = f"Can't find module {module_name}" + raise ValueError(msg) from e + class_ = getattr(module, class_name) + if class_ is None: + msg = f"Can'f find class {class_name} in module {module_name}" + raise ValueError(msg) + else: + class_ = invocation_layer_class + + if class_: + return class_(model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs) for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers: if inspect.isabstract(invocation_layer): diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index 92ec069acb..01e04a8811 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -57,6 +57,7 @@ def __init__( output_variable: Optional[str] = None, max_length: Optional[int] = 100, api_key: Optional[str] = None, + api_base: Optional[str] = None, timeout: Optional[float] = None, use_auth_token: Optional[Union[str, bool]] = None, use_gpu: Optional[bool] = None, @@ -65,6 +66,7 @@ def __init__( top_k: int = 1, debug: Optional[bool] = False, model_kwargs: Optional[Dict] = None, + truncate: bool = True, ): """ Creates a PromptNode instance. @@ -82,6 +84,7 @@ def __init__( :param stop_words: Stops text generation if any of the stop words is generated. :param model_kwargs: Additional keyword arguments passed when loading the model specified in `model_name_or_path`. :param debug: Whether to include the used prompts as debug information in the output under the key _debug. + :param truncate: Whether to truncate the prompt to the maximum token limit before sending it to the model. 
Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (the URL for the Azure OpenAI API endpoint, usually in the form `https://.openai.azure.com') and @@ -108,12 +111,14 @@ def __init__( self.stop_words: Optional[List[str]] = stop_words self.top_k: int = top_k self.debug = debug + self.truncate = truncate if isinstance(model_name_or_path, str): self.prompt_model = PromptModel( model_name_or_path=model_name_or_path, max_length=max_length, api_key=api_key, + api_base=api_base, timeout=timeout, use_auth_token=use_auth_token, use_gpu=use_gpu, @@ -163,7 +168,8 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, * for prompt in template_to_fill.fill(*args, **kwargs): kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs)) # and pass the prepared prompt and kwargs copy to the model - prompt = self.prompt_model._ensure_token_limit(prompt) + if self.truncate: + prompt = self.prompt_model._ensure_token_limit(prompt) prompt_collector.append(prompt) logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy) output = self.prompt_model.invoke(prompt, **kwargs_copy) @@ -175,7 +181,8 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, * # straightforward prompt, no templates used for prompt in list(args): kwargs_copy = copy.copy(kwargs) - prompt = self.prompt_model._ensure_token_limit(prompt) + if self.truncate: + prompt = self.prompt_model._ensure_token_limit(prompt) prompt_collector.append(prompt) logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s ", prompt, kwargs_copy) output = self.prompt_model.invoke(prompt, **kwargs_copy) @@ -240,21 +247,24 @@ def _prepare( # type: ignore """ Prepare prompt invocation. 
""" - invocation_context = invocation_context or {} + if invocation_context is None: + invocation_context = {} + else: + invocation_context = invocation_context.copy() - if query and "query" not in invocation_context: + if query is not None and "query" not in invocation_context: invocation_context["query"] = query - if file_paths and "file_paths" not in invocation_context: + if file_paths is not None and "file_paths" not in invocation_context: invocation_context["file_paths"] = file_paths - if labels and "labels" not in invocation_context: + if labels is not None and "labels" not in invocation_context: invocation_context["labels"] = labels - if documents and "documents" not in invocation_context: + if documents is not None and "documents" not in invocation_context: invocation_context["documents"] = documents - if meta and "meta" not in invocation_context: + if meta is not None and "meta" not in invocation_context: invocation_context["meta"] = meta if "prompt_template" not in invocation_context: @@ -343,7 +353,8 @@ async def _aprompt(self, prompt_template: Optional[Union[str, PromptTemplate]], for prompt in template_to_fill.fill(*args, **kwargs): kwargs_copy = template_to_fill.remove_template_params(copy.copy(kwargs)) # and pass the prepared prompt and kwargs copy to the model - prompt = self.prompt_model._ensure_token_limit(prompt) + if self.truncate: + prompt = self.prompt_model._ensure_token_limit(prompt) prompt_collector.append(prompt) logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy) output = await self.prompt_model.ainvoke(prompt, **kwargs_copy) @@ -355,7 +366,8 @@ async def _aprompt(self, prompt_template: Optional[Union[str, PromptTemplate]], # straightforward prompt, no templates used for prompt in list(args): kwargs_copy = copy.copy(kwargs) - prompt = self.prompt_model._ensure_token_limit(prompt) + if self.truncate: + prompt = self.prompt_model._ensure_token_limit(prompt) prompt_collector.append(prompt) logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s ", prompt, kwargs_copy) output = await self.prompt_model.ainvoke(prompt, **kwargs_copy) diff --git a/haystack/nodes/prompt/prompt_template.py b/haystack/nodes/prompt/prompt_template.py index f90ccc17ad..4d4c73febb 100644 --- a/haystack/nodes/prompt/prompt_template.py +++ b/haystack/nodes/prompt/prompt_template.py @@ -539,8 +539,8 @@ def post_process(self, prompt_output: List[str], **kwargs) -> List[Any]: if self.output_parser: invocation_context = kwargs invocation_context["results"] = prompt_output - self.output_parser.run(invocation_context=invocation_context) - return invocation_context[self.output_parser.outputs[0]] + parser_results, _ = self.output_parser.run(invocation_context=invocation_context) + return parser_results[self.output_parser.outputs[0]] else: return prompt_output diff --git a/haystack/nodes/ranker/sentence_transformers.py b/haystack/nodes/ranker/sentence_transformers.py index 17d82055e5..96ea6184f7 100644 --- a/haystack/nodes/ranker/sentence_transformers.py +++ b/haystack/nodes/ranker/sentence_transformers.py @@ -18,6 +18,7 @@ from torch.nn import DataParallel from transformers import AutoModelForSequenceClassification, AutoTokenizer from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports + from haystack.utils.torch_utils import resolve_torch_dtype class SentenceTransformersRanker(BaseRanker): @@ -57,6 +58,7 @@ def __init__( progress_bar: bool = True, use_auth_token: Optional[Union[str, bool]] = None, 
embed_meta_fields: Optional[List[str]] = None, + model_kwargs: Optional[dict] = None, ): """ :param model_name_or_path: Directory of a saved model or the name of a public model e.g. @@ -90,8 +92,16 @@ def __init__( self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True) self.progress_bar = progress_bar + self.model_kwargs = model_kwargs + kwargs = model_kwargs if model_kwargs else {} + torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype")) + if torch_dtype: + kwargs["torch_dtype"] = torch_dtype self.transformer_model = AutoModelForSequenceClassification.from_pretrained( - pretrained_model_name_or_path=model_name_or_path, revision=model_version, use_auth_token=use_auth_token + pretrained_model_name_or_path=model_name_or_path, + revision=model_version, + use_auth_token=use_auth_token, + **kwargs, ) self.transformer_model.to(str(self.devices[0])) self.transformer_tokenizer = AutoTokenizer.from_pretrained( diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py index 94847339d1..31d4f07cf7 100644 --- a/haystack/nodes/reader/farm.py +++ b/haystack/nodes/reader/farm.py @@ -38,6 +38,7 @@ from haystack.modeling.training import Trainer, DistillationTrainer, TinyBERTDistillationTrainer from haystack.modeling.evaluation import Evaluator from haystack.modeling.utils import set_all_seeds, initialize_device_settings + from haystack.utils.torch_utils import resolve_torch_dtype class FARMReader(BaseReader): @@ -77,6 +78,7 @@ def __init__( use_auth_token: Optional[Union[str, bool]] = None, max_query_length: int = 64, preprocessing_batch_size: Optional[int] = None, + model_kwargs: Optional[dict] = None, ): """ :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bert-base-cased', @@ -149,6 +151,10 @@ def __init__( self.return_no_answers = return_no_answer self.top_k = top_k self.top_k_per_candidate = top_k_per_candidate + kwargs = model_kwargs if model_kwargs else {} + torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype")) + if torch_dtype: + kwargs["torch_dtype"] = torch_dtype self.inferencer = QAInferencer.load( model_name_or_path, batch_size=batch_size, @@ -166,6 +172,7 @@ def __init__( devices=self.devices, # type: ignore [arg-type] use_auth_token=use_auth_token, max_query_length=max_query_length, + **kwargs, ) self.inferencer.model.prediction_heads[0].context_window_size = context_window_size self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost @@ -936,11 +943,33 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = predictions = self._deduplicate_predictions(predictions, documents) # assemble answers from all the different documents & format them. 
answers, max_no_ans_gap = self._extract_answers_of_predictions(predictions, top_k) + answers = [self._add_answer_page_number(documents=documents, answer=answer) for answer in answers] # TODO: potentially simplify return here to List[Answer] and handle no_ans_gap differently result = {"query": query, "no_ans_gap": max_no_ans_gap, "answers": answers} return result + def _add_answer_page_number(self, documents: List[Document], answer: Answer) -> Answer: + # Following the implementation of BaseReader.add_doc_meta_data_to_answer + if answer.meta is None: + answer.meta = {} + + if answer.offsets_in_document is None: + return answer + + # Calculate the answer page number + meta_to_add = {} + if answer.document_ids: + for doc in documents: + if doc.id in answer.document_ids and ("page_number" in doc.meta): + ans_start = answer.offsets_in_document[0].start # type: ignore + answer_page_number = doc.meta["page_number"] + doc.content[:ans_start].count("\f") + meta_to_add = {"answer_page_number": answer_page_number} + break + + answer.meta.update(meta_to_add) + return answer + def eval_on_file( self, data_dir: Union[Path, str], diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py index d07e7ab1b8..d05af6ee45 100644 --- a/haystack/nodes/retriever/_embedding_encoder.py +++ b/haystack/nodes/retriever/_embedding_encoder.py @@ -16,7 +16,7 @@ HAYSTACK_REMOTE_API_MAX_RETRIES, HAYSTACK_REMOTE_API_TIMEOUT_SEC, ) -from haystack.errors import CohereError, CohereUnauthorizedError +from haystack.errors import AWSConfigurationError, CohereError, CohereUnauthorizedError from haystack.nodes.retriever._openai_encoder import _OpenAIEmbeddingEncoder from haystack.schema import Document from haystack.telemetry import send_event @@ -42,6 +42,9 @@ from haystack.modeling.infer import Inferencer from haystack.nodes.retriever._losses import _TRAINING_LOSSES +with LazyImport(message="Run 'pip install boto3'") as boto3_import: + import boto3 + from botocore.exceptions import BotoCoreError COHERE_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30)) COHERE_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10)) @@ -55,6 +58,8 @@ "embed-multilingual-v2.0", ] +BEDROCK_EMBEDDING_MODELS = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"] + class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder): def __init__(self, retriever: "EmbeddingRetriever"): @@ -130,7 +135,10 @@ def __init__(self, retriever: "EmbeddingRetriever"): # e.g. 'roberta-base-nli-stsb-mean-tokens' torch_and_transformers_import.check() self.embedding_model = SentenceTransformer( - retriever.embedding_model, device=str(retriever.devices[0]), use_auth_token=retriever.use_auth_token + retriever.embedding_model, + device=str(retriever.devices[0]), + use_auth_token=retriever.use_auth_token, + revision=retriever.model_version, ) self.batch_size = retriever.batch_size self.embedding_model.max_seq_length = retriever.max_seq_len @@ -434,6 +442,107 @@ def save(self, save_dir: Union[Path, str]): raise NotImplementedError(f"Saving is not implemented for {self.__class__}") +class _BedrockEmbeddingEncoder(_BaseEmbeddingEncoder): + def __init__(self, retriever: "EmbeddingRetriever"): + """Embedding Encoder for Bedrock models + See https://docs.aws.amazon.com/bedrock/latest/userguide/embeddings.html for more details. + The maximum input text is 8K tokens and the maximum output vector length is 1536. 
+ Titan embeddings do not support batch operations. + + :param retriever: EmbeddingRetriever object + """ + boto3_import.check() + if retriever.embedding_model not in BEDROCK_EMBEDDING_MODELS: + raise ValueError("Model not supported by Bedrock Embedding Encoder") + self.model = retriever.embedding_model + self.client = self._initialize_boto3_session(retriever.aws_config).client("bedrock-runtime") + + def _initialize_boto3_session(self, aws_config: Optional[Dict[str, Any]]): + if aws_config is None: + raise ValueError( + "`aws_config` is not set. To use Bedrock models, you should set `aws_config` when initializing the retriever." + ) + + aws_access_key_id = aws_config.get("aws_access_key_id", None) + aws_secret_access_key = aws_config.get("aws_secret_access_key", None) + aws_session_token = aws_config.get("aws_session_token", None) + region_name = aws_config.get("region_name", None) + profile_name = aws_config.get("profile_name", None) + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=region_name, + profile_name=profile_name, + ) + except BotoCoreError as e: + raise AWSConfigurationError( + f"Failed to initialize the session with provided AWS credentials {aws_config}" + ) from e + + def _embed_batch_cohere( + self, texts: List[str], input_type: Literal["search_query", "search_document"] + ) -> np.ndarray: + cohere_payload = {"texts": texts, "input_type": input_type, "truncate": "RIGHT"} + response = self._invoke_model(cohere_payload) + embeddings = np.array(response["embeddings"]) + return embeddings + + def _embed_titan(self, text: str) -> np.ndarray: + titan_payload = {"inputText": text} + response = self._invoke_model(titan_payload) + embeddings = np.array(response["embedding"]) + return embeddings + + def _invoke_model(self, payload: Dict[str, Any]) -> Dict[str, Any]: + body = json.dumps(payload) + response = self.client.invoke_model( + body=body, modelId=self.model, accept="application/json", contentType="application/json" + ) + body = response.get("body").read().decode("utf-8") + response_body = json.loads(body) + return response_body + + def embed_queries(self, queries: List[str]) -> np.ndarray: + if self.model == "amazon.titan-embed-text-v1": + all_embeddings = [] + for query in queries: + generated_embeddings = self._embed_titan(query) + all_embeddings.append(generated_embeddings) + return np.stack(all_embeddings) + else: + return self._embed_batch_cohere(queries, input_type="search_query") + + def embed_documents(self, docs: List[Document]) -> np.ndarray: + if self.model == "amazon.titan-embed-text-v1": + all_embeddings = [] + for doc in docs: + generated_embeddings = self._embed_titan(doc.content) + all_embeddings.append(generated_embeddings) + return np.stack(all_embeddings) + else: + contents = [d.content for d in docs] + return self._embed_batch_cohere(contents, input_type="search_document") + + def train( + self, + training_data: List[Dict[str, Any]], + learning_rate: float = 2e-5, + n_epochs: int = 1, + num_warmup_steps: Optional[int] = None, + batch_size: int = 16, + train_loss: Literal["mnrl", "margin_mse"] = "mnrl", + num_workers: int = 0, + use_amp: bool = False, + **kwargs, + ): + raise NotImplementedError(f"Training is not implemented for {self.__class__}") + + def save(self, save_dir: Union[Path, str]): + raise NotImplementedError(f"Saving is not implemented for {self.__class__}") + + _EMBEDDING_ENCODERS: Dict[str, Callable] = { "farm": 
_DefaultEmbeddingEncoder, "transformers": _DefaultEmbeddingEncoder, @@ -441,4 +550,5 @@ def save(self, save_dir: Union[Path, str]): "retribert": _RetribertEmbeddingEncoder, "openai": _OpenAIEmbeddingEncoder, "cohere": _CohereEmbeddingEncoder, + "bedrock": _BedrockEmbeddingEncoder, } diff --git a/haystack/nodes/retriever/_openai_encoder.py b/haystack/nodes/retriever/_openai_encoder.py index 4af3cd1fa8..f66e3191c5 100644 --- a/haystack/nodes/retriever/_openai_encoder.py +++ b/haystack/nodes/retriever/_openai_encoder.py @@ -44,9 +44,9 @@ def __init__(self, retriever: "EmbeddingRetriever"): self.batch_size = min(64, retriever.batch_size) self.progress_bar = retriever.progress_bar model_class: str = next( - (m for m in ["ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), "babbage" + (m for m in ["3-large", "3-small", "ada", "babbage", "davinci", "curie"] if m in retriever.embedding_model), + "babbage", ) - tokenizer = self._setup_encoding_models(model_class, retriever.embedding_model, retriever.max_seq_len) self._tokenizer = load_openai_tokenizer(tokenizer_name=tokenizer) @@ -57,7 +57,7 @@ def _setup_encoding_models(self, model_class: str, model_name: str, max_seq_len: tokenizer_name = "gpt2" # new generation of embedding models (December 2022), we need to specify the full name - if model_name.endswith("-002"): + if model_name.endswith("-002") or model_name in ["text-embedding-3-large", "text-embedding-3-small"]: self.query_encoder_model = model_name self.doc_encoder_model = model_name self.max_seq_len = min(8191, max_seq_len) diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index d4663e6cfd..6678aca939 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -17,7 +17,11 @@ from haystack.schema import Document, FilterType from haystack.document_stores import BaseDocumentStore from haystack.nodes.retriever.base import BaseRetriever -from haystack.nodes.retriever._embedding_encoder import _EMBEDDING_ENCODERS, COHERE_EMBEDDING_MODELS +from haystack.nodes.retriever._embedding_encoder import ( + _EMBEDDING_ENCODERS, + COHERE_EMBEDDING_MODELS, + BEDROCK_EMBEDDING_MODELS, +) from haystack.utils.early_stopping import EarlyStopping from haystack.telemetry import send_event from haystack.lazy_imports import LazyImport @@ -1468,13 +1472,15 @@ def __init__( azure_deployment_name: Optional[str] = None, api_base: str = "https://api.openai.com/v1", openai_organization: Optional[str] = None, + aws_config: Optional[Dict[str, Any]] = None, ): """ :param document_store: An instance of DocumentStore from which to retrieve documents. :param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'sentence-transformers/all-MiniLM-L6-v2'``. The embedding model could also potentially be an OpenAI model ["ada", "babbage", "davinci", "curie"] or - a Cohere model ["embed-english-v2.0", "embed-english-light-v2.0", "embed-multilingual-v2.0"]. + a Cohere model ["embed-english-v2.0", "embed-english-light-v2.0", "embed-multilingual-v2.0"] or + an AWS Bedrock model ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]. :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available. :param batch_size: Number of documents to encode at once. @@ -1489,6 +1495,7 @@ def __init__( 4. 
`retribert` : (will use `_RetribertEmbeddingEncoder` as embedding encoder) 5. `openai` : (will use `_OpenAIEmbeddingEncoder` as embedding encoder) 6. `cohere` : (will use `_CohereEmbeddingEncoder` as embedding encoder) + 7. `bedrock` : (will use `_BedrockEmbeddingEncoder` as embedding encoder) :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only). Options: @@ -1533,6 +1540,7 @@ def __init__( :param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`. :param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). + :param aws_config: The aws_config contains {aws_access_key, aws_secret_key, aws_region, profile_name} to use with the boto3 Session for an AWS Bedrock retriever. Defaults to 'None'. """ torch_and_transformers_import.check() @@ -1565,6 +1573,7 @@ def __init__( self.azure_base_url = azure_base_url self.azure_deployment_name = azure_deployment_name self.openai_organization = openai_organization + self.aws_config = aws_config self.model_format = ( self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token) if model_format is None @@ -1885,13 +1894,21 @@ def _preprocess_documents(self, docs: List[Document]) -> List[Document]: @staticmethod def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[str, bool]]) -> str: - valid_openai_model_name = model_name_or_path in ["ada", "babbage", "davinci", "curie"] or any( - m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"] - ) + # pylint: disable=too-many-return-statements + valid_openai_model_name = model_name_or_path in [ + "ada", + "babbage", + "davinci", + "curie", + "text-embedding-3-small", + "text-embedding-3-large", + ] or any(m in model_name_or_path for m in ["-ada-", "-babbage-", "-davinci-", "-curie-"]) if valid_openai_model_name: return "openai" if model_name_or_path in COHERE_EMBEDDING_MODELS: return "cohere" + if model_name_or_path in BEDROCK_EMBEDDING_MODELS: + return "bedrock" # Check if model name is a local directory with sentence transformers config file in it if Path(model_name_or_path).exists(): if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists(): diff --git a/haystack/nodes/retriever/link_content.py b/haystack/nodes/retriever/link_content.py index 7f353cb7ff..f49bc0e39f 100644 --- a/haystack/nodes/retriever/link_content.py +++ b/haystack/nodes/retriever/link_content.py @@ -1,5 +1,4 @@ import inspect -import io import logging from collections import defaultdict from datetime import datetime @@ -14,15 +13,11 @@ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, RetryCallState from haystack import __version__ -from haystack.lazy_imports import LazyImport from haystack.nodes import PreProcessor, BaseComponent from haystack.schema import Document, MultiLabel logger = logging.getLogger(__name__) -with LazyImport("Run 'pip install farm-haystack[pdf]'") as fitz_import: - import fitz - def html_content_handler(response: Response) -> Optional[str]: """ @@ -34,20 +29,6 @@ def html_content_handler(response: Response) -> Optional[str]: return extractor.get_content(response.text) -def pdf_content_handler(response: Response) -> Optional[str]: - """ - Extracts text from PDF response stream using the PyMuPDF library. 
- - :param response: Response object from the request. - :return: The extracted text. - """ - file_path = io.BytesIO(response.content) - with fitz.open(stream=file_path, filetype="pdf") as doc: - text = "\f".join([page.get_text() for page in doc]) - - return text.encode("ascii", errors="ignore").decode() - - class LinkContentFetcher(BaseComponent): """ LinkContentFetcher fetches content from a URL and converts it into a list of Document objects. @@ -153,8 +134,6 @@ def __init__( # register default content handlers self._register_content_handler("text/html", html_content_handler) - if fitz_import.is_successful(): - self._register_content_handler("application/pdf", pdf_content_handler) # register custom content handlers, can override default handlers if content_handlers: diff --git a/haystack/nodes/sampler/top_p_sampler.py b/haystack/nodes/sampler/top_p_sampler.py index b77e448760..60d09d83a9 100644 --- a/haystack/nodes/sampler/top_p_sampler.py +++ b/haystack/nodes/sampler/top_p_sampler.py @@ -35,7 +35,7 @@ class TopPSampler(BaseSampler): ```python prompt_node = PromptNode( - "text-davinci-003", + "gpt-3.5-turbo-instruct", api_key=openai_key, max_length=256, default_prompt_template="question-answering-with-document-scores", diff --git a/haystack/nodes/summarizer/transformers.py b/haystack/nodes/summarizer/transformers.py index cad388a36b..35bab664fe 100644 --- a/haystack/nodes/summarizer/transformers.py +++ b/haystack/nodes/summarizer/transformers.py @@ -18,6 +18,7 @@ from transformers import pipeline from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports from haystack.utils.torch_utils import ListDataset + from haystack.utils.torch_utils import resolve_torch_dtype class TransformersSummarizer(BaseSummarizer): @@ -68,6 +69,7 @@ def __init__( progress_bar: bool = True, use_auth_token: Optional[Union[str, bool]] = None, devices: Optional[List[Union[str, "torch.device"]]] = None, + pipeline_kwargs: Optional[dict] = None, ): """ Load a summarization model from transformers. 
@@ -107,6 +109,10 @@ def __init__( if tokenizer is None: tokenizer = model_name_or_path + kwargs = pipeline_kwargs if pipeline_kwargs else {} + torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype")) + if torch_dtype: + kwargs["torch_dtype"] = torch_dtype self.summarizer = pipeline( task="summarization", model=model_name_or_path, @@ -114,6 +120,7 @@ def __init__( revision=model_version, device=self.devices[0], use_auth_token=use_auth_token, + **kwargs, ) self.max_length = max_length self.min_length = min_length @@ -121,6 +128,7 @@ def __init__( self.print_log: Set[str] = set() self.batch_size = batch_size self.progress_bar = progress_bar + self.pipeline_kwargs = pipeline_kwargs def predict(self, documents: List[Document]) -> List[Document]: """ @@ -159,6 +167,7 @@ def predict(self, documents: List[Document]) -> List[Document]: return_text=True, clean_up_tokenization_spaces=self.clean_up_tokenization_spaces, truncation=True, + batch_size=self.batch_size, ) result: List[Document] = [] diff --git a/haystack/pipelines/__init__.py b/haystack/pipelines/__init__.py index ca770cd05a..a7a414e79a 100644 --- a/haystack/pipelines/__init__.py +++ b/haystack/pipelines/__init__.py @@ -9,7 +9,6 @@ MostSimilarDocumentsPipeline, QuestionAnswerGenerationPipeline, RetrieverQuestionGenerationPipeline, - GenerativeQAPipeline, ExtractiveQAPipeline, FAQPipeline, TextIndexingPipeline, diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index 0b93b119cc..31b491261c 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -39,7 +39,7 @@ from haystack.utils.deepsetcloud import DeepsetCloud from haystack.schema import Answer, EvaluationResult, MultiLabel, Document, Span from haystack.errors import HaystackError, PipelineError, PipelineConfigError, DocumentStoreError -from haystack.nodes import BaseGenerator, Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator +from haystack.nodes import Docs2Answers, BaseReader, BaseSummarizer, BaseTranslator, QuestionGenerator from haystack.nodes.base import BaseComponent, RootNode from haystack.nodes.retriever.base import BaseRetriever from haystack.document_stores.base import BaseDocumentStore @@ -210,7 +210,7 @@ def load_from_deepset_cloud( params["api_key"] = api_key component_config["params"] = params - del pipeline_config["name"] # Would fail validation otherwise + pipeline_config.pop("name", None) # Would fail validation otherwise pipeline = cls.load_from_config( pipeline_config=pipeline_config, pipeline_name=pipeline_name, @@ -596,19 +596,27 @@ def run( # type: ignore **existing_input.get("_debug", {}), **node_output.get("_debug", {}), } - if query: + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **updated_input} + if query and "query" not in updated_input: updated_input["query"] = query - if file_paths: + if file_paths and "file_paths" not in updated_input: updated_input["file_paths"] = file_paths - if labels: + if labels and "labels" not in updated_input: updated_input["labels"] = labels - if documents: + if documents and "documents" not in updated_input: updated_input["documents"] = documents - if meta: + if meta and "meta" not in updated_input: updated_input["meta"] = meta else: existing_input["inputs"].append(node_output) - updated_input = existing_input + if "_debug" in node_output.keys(): + existing_input["_debug"] = { + **existing_input.get("_debug", {}), + **node_output.get("_debug", {}), + } + additional_input = 
self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **existing_input} queue[n] = updated_input else: queue[n] = node_output @@ -618,6 +626,22 @@ def run( # type: ignore return node_output + def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]: + """ + Combines the outputs of two nodes into a single input for a downstream node. + For matching keys first node's (existing_input) value is kept. This is used for join nodes. + + :param existing_input: The output of the first node. + :param node_output: The output of the second node. + """ + additional_input = {} + combined = {**node_output, **existing_input} + for key in combined: + # Don't overwrite these keys since they are set in Pipeline.run + if key not in ["inputs", "params", "_debug"]: + additional_input[key] = combined[key] + return additional_input + async def _arun( # noqa: C901,PLR0912 type: ignore self, query: Optional[str] = None, @@ -734,19 +758,27 @@ async def _arun( # noqa: C901,PLR0912 type: ignore **existing_input.get("_debug", {}), **node_output.get("_debug", {}), } - if query: + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **updated_input} + if query and "query" not in updated_input: updated_input["query"] = query - if file_paths: + if file_paths and "file_paths" not in updated_input: updated_input["file_paths"] = file_paths - if labels: + if labels and "labels" not in updated_input: updated_input["labels"] = labels - if documents: + if documents and "documents" not in updated_input: updated_input["documents"] = documents - if meta: + if meta and "meta" not in updated_input: updated_input["meta"] = meta else: existing_input["inputs"].append(node_output) - updated_input = existing_input + if "_debug" in node_output.keys(): + existing_input["_debug"] = { + **existing_input.get("_debug", {}), + **node_output.get("_debug", {}), + } + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **existing_input} queue[n] = updated_input else: queue[n] = node_output @@ -756,6 +788,7 @@ async def _arun( # noqa: C901,PLR0912 type: ignore return node_output + # pylint: disable=too-many-branches def run_batch( # noqa: C901,PLR0912 type: ignore self, queries: Optional[List[str]] = None, @@ -896,19 +929,32 @@ def run_batch( # noqa: C901,PLR0912 type: ignore existing_input = queue[n] if "inputs" not in existing_input.keys(): updated_input: Dict = {"inputs": [existing_input, node_output], "params": params} - if queries: + if "_debug" in existing_input.keys() or "_debug" in node_output.keys(): + updated_input["_debug"] = { + **existing_input.get("_debug", {}), + **node_output.get("_debug", {}), + } + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **updated_input} + if queries and "queries" not in updated_input: updated_input["queries"] = queries - if file_paths: + if file_paths and "file_paths" not in updated_input: updated_input["file_paths"] = file_paths - if labels: + if labels and "labels" not in updated_input: updated_input["labels"] = labels - if documents: + if documents and "documents" not in updated_input: updated_input["documents"] = documents - if meta: + if meta and "meta" not in updated_input: updated_input["meta"] = meta else: existing_input["inputs"].append(node_output) - updated_input = existing_input + if "_debug" in node_output.keys(): + 
existing_input["_debug"] = { + **existing_input.get("_debug", {}), + **node_output.get("_debug", {}), + } + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **existing_input} queue[n] = updated_input else: queue[n] = node_output @@ -2528,9 +2574,6 @@ def get_type(self) -> str: pipeline_types = { # QuestionGenerationPipeline has only one component, which is a QuestionGenerator "QuestionGenerationPipeline": lambda x: all(isinstance(x, QuestionGenerator) for x in x.values()), - # GenerativeQAPipeline has at least BaseGenerator and BaseRetriever components - "GenerativeQAPipeline": lambda x: any(isinstance(x, BaseRetriever) for x in x.values()) - and any(isinstance(x, BaseGenerator) for x in x.values()), # FAQPipeline has at least one Docs2Answers component "FAQPipeline": lambda x: any(isinstance(x, Docs2Answers) for x in x.values()), # ExtractiveQAPipeline has at least one BaseRetriever component and one BaseReader component diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py index e1f8f61c72..0f2073cd3d 100644 --- a/haystack/pipelines/standard_pipelines.py +++ b/haystack/pipelines/standard_pipelines.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Union, Literal from haystack.document_stores.base import BaseDocumentStore, FilterType -from haystack.nodes.answer_generator.base import BaseGenerator from haystack.nodes.other.docs2answers import Docs2Answers from haystack.nodes.other.document_merger import DocumentMerger from haystack.nodes.question_generator.question_generator import QuestionGenerator @@ -411,35 +410,6 @@ def run(self, query: str, params: Optional[dict] = None, debug: Optional[bool] = return output -class GenerativeQAPipeline(BaseStandardPipeline): - """ - Pipeline for Generative Question Answering. - """ - - def __init__(self, generator: BaseGenerator, retriever: BaseRetriever): - """ - :param generator: Generator instance - :param retriever: Retriever instance - """ - self.pipeline = Pipeline() - self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"]) - self.pipeline.add_node(component=generator, name="Generator", inputs=["Retriever"]) - - def run(self, query: str, params: Optional[dict] = None, debug: Optional[bool] = None): - """ - :param query: the query string. - :param params: params for the `retriever` and `generator`. For instance, - params={"Retriever": {"top_k": 10}, "Generator": {"top_k": 5}} - :param debug: Whether the pipeline should instruct nodes to collect debug information - about their execution. By default these include the input parameters - they received and the output they generated. - All debug information can then be found in the dict returned - by this method under the key "_debug" - """ - output = self.pipeline.run(query=query, params=params, debug=debug) - return output - - class SearchSummarizationPipeline(BaseStandardPipeline): """ Pipeline that retrieves documents for a query and then summarizes those documents. 
diff --git a/haystack/preview/README.md b/haystack/preview/README.md deleted file mode 100644 index 1dcd4c351b..0000000000 --- a/haystack/preview/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Haystack 2.0 - Preview Features - -[![PyPI - Version](https://img.shields.io/pypi/v/haystack-ai.svg)](https://pypi.org/project/haystack-ai) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/haystack-ai.svg)](https://pypi.org/project/haystack-ai) - -Since Haystack 1.15, we’ve been slowly introducing new components and features to Haystack in the background in preparation for Haystack 2.0. In this `preview` module, you can find what’s been implemented so far regarding Haystack 2.0. **Keep in mind that Haystack 2.0 is still a work in progress.** Read more about Haystack 2.0 in [Shaping Haystack 2.0](https://github.com/deepset-ai/haystack/discussions/5568). - -## 💾 Installation - -**Install `haystack-ai`** - -There is a separate PyPI package that only ships the code in `preview` module. You can install `haystack-ai` using pip: -```sh -pip install haystack-ai -``` -The `haystack-ai` package is built on the `main` branch, so it's highly unstable, but it's useful if you want to try the new features as soon as they are merged. - -**Install `farm-haystack`** - -As an alternative way, you can install `farm-haystack`: -```sh -pip install farm-haystack -``` -The `farm-haystack` package includes all new features of Haystack 2.0. Note that updates to this package occur less frequently compared to `haystack-ai`. So, you might not get the all latest Haystack 2.0 features immediately when using `farm-haystack`. - -## 🚗 Getting Started - -In our **end 2 end tests** you can find example code for the following pipelines: -- [RAG pipeline](https://github.com/deepset-ai/haystack/blob/main/e2e/preview/pipelines/test_rag_pipelines.py) -- [Extractive QA pipeline](https://github.com/deepset-ai/haystack/blob/main/e2e/preview/pipelines/test_extractive_qa_pipeline.py) -- more to come, check out the [folder](https://github.com/deepset-ai/haystack/blob/main/e2e/preview/) - -## 💙 Stay Updated -To learn how and when components will be migrated to the new major version, have a look at the [Migrate Components to Pipeline v2](https://github.com/deepset-ai/haystack/issues/5265) roadmap item, where we keep track of issues and PRs about Haystack 2.0. When you have questions, you can always contact us using the [Shaping Haystack 2.0](https://github.com/deepset-ai/haystack/discussions/5568) discussion or [Haystack Discord server](https://discord.com/channels/993534733298450452/1141683185458094211). 
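
Separately, the link_content.py hunk earlier in this diff drops the built-in application/pdf handler from LinkContentFetcher. The following is a hedged sketch of how callers could restore that behaviour themselves; it assumes PyMuPDF is installed and that the content_handlers constructor argument (keyed by MIME type, as the remaining registration code suggests) and the fetch method are unchanged, and the URL is a placeholder.

import io
from typing import Optional

import fitz  # PyMuPDF, e.g. via `pip install farm-haystack[pdf]`
from requests import Response

from haystack.nodes import LinkContentFetcher


def pdf_content_handler(response: Response) -> Optional[str]:
    # Mirrors the handler deleted above: join pages with form feeds, drop non-ASCII bytes.
    stream = io.BytesIO(response.content)
    with fitz.open(stream=stream, filetype="pdf") as doc:
        text = "\f".join(page.get_text() for page in doc)
    return text.encode("ascii", errors="ignore").decode()


# Custom handlers override the defaults for their content type.
fetcher = LinkContentFetcher(content_handlers={"application/pdf": pdf_content_handler})
documents = fetcher.fetch(url="https://example.com/whitepaper.pdf")
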
diff --git a/haystack/preview/__init__.py b/haystack/preview/__init__.py deleted file mode 100644 index 4d467a8e0a..0000000000 --- a/haystack/preview/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from canals import component -from canals.serialization import default_from_dict, default_to_dict -from canals.errors import DeserializationError, ComponentError -from haystack.preview.pipeline import Pipeline -from haystack.preview.dataclasses import Document, Answer, GeneratedAnswer, ExtractedAnswer - - -__all__ = [ - "component", - "default_from_dict", - "default_to_dict", - "DeserializationError", - "ComponentError", - "Pipeline", - "Document", - "Answer", - "GeneratedAnswer", - "ExtractedAnswer", -] diff --git a/haystack/preview/components/__init__.py b/haystack/preview/components/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/haystack/preview/components/audio/__init__.py b/haystack/preview/components/audio/__init__.py deleted file mode 100644 index 3d3b07cd82..0000000000 --- a/haystack/preview/components/audio/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber -from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber - -__all__ = ["LocalWhisperTranscriber", "RemoteWhisperTranscriber"] diff --git a/haystack/preview/components/audio/whisper_local.py b/haystack/preview/components/audio/whisper_local.py deleted file mode 100644 index f45f652df2..0000000000 --- a/haystack/preview/components/audio/whisper_local.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence - -import logging -from pathlib import Path - -from haystack.preview import component, Document, default_to_dict, ComponentError -from haystack.preview.lazy_imports import LazyImport - -with LazyImport( - "Run 'pip install transformers[torch]' to install torch and " - "'pip install \"openai-whisper>=20231106\"' to install whisper." -) as whisper_import: - import torch - import whisper - - -logger = logging.getLogger(__name__) -WhisperLocalModel = Literal["tiny", "small", "medium", "large", "large-v2"] - - -@component -class LocalWhisperTranscriber: - """ - Transcribes audio files using OpenAI's Whisper's model on your local machine. - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper - [github repo](https://github.com/openai/whisper). - """ - - def __init__( - self, - model_name_or_path: WhisperLocalModel = "large", - device: Optional[str] = None, - whisper_params: Optional[Dict[str, Any]] = None, - ): - """ - :param model_name_or_path: Name of the model to use. Set it to one of the following values: - :type model_name_or_path: Literal["tiny", "small", "medium", "large", "large-v2"] - :param device: Name of the torch device to use for inference. If None, CPU is used. - :type device: Optional[str] - """ - whisper_import.check() - if model_name_or_path not in get_args(WhisperLocalModel): - raise ValueError( - f"Model name '{model_name_or_path}' not recognized. Choose one among: " - f"{', '.join(get_args(WhisperLocalModel))}." - ) - self.model_name = model_name_or_path - self.whisper_params = whisper_params or {} - self.device = torch.device(device) if device else torch.device("cpu") - self._model = None - - def warm_up(self) -> None: - """ - Loads the model. 
- """ - if not self._model: - self._model = whisper.load_model(self.model_name, device=self.device) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params - ) - - @component.output_types(documents=List[Document]) - def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None): - """ - Transcribe the audio files into a list of Documents, one for each input file. - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper - [github repo](https://github.com/openai/whisper). - - :param audio_files: A list of paths or binary streams to transcribe. - :returns: A list of Documents, one for each file. The content of the document is the transcription text, - while the document's metadata contains all the other values returned by the Whisper model, such as the - alignment data. Another key called `audio_file` contains the path to the audio file used for the - transcription. - """ - if self._model is None: - raise ComponentError("The component was not warmed up. Run 'warm_up()' before calling 'run()'.") - - if whisper_params is None: - whisper_params = self.whisper_params - - documents = self.transcribe(audio_files, **whisper_params) - return {"documents": documents} - - def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]: - """ - Transcribe the audio files into a list of Documents, one for each input file. - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper - [github repo](https://github.com/openai/whisper). - - :param audio_files: A list of paths or binary streams to transcribe. - :returns: A list of Documents, one for each file. The content of the document is the transcription text, - while the document's metadata contains all the other values returned by the Whisper model, such as the - alignment data. Another key called `audio_file` contains the path to the audio file used for the - transcription. - """ - transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs) - documents = [] - for audio, transcript in zip(audio_files, transcriptions): - content = transcript.pop("text") - if not isinstance(audio, (str, Path)): - audio = "<>" - doc = Document(content=content, meta={"audio_file": audio, **transcript}) - documents.append(doc) - return documents - - def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]: - """ - Transcribe the given audio files. Returns the output of the model, a dictionary, for each input file. - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper - [github repo](https://github.com/openai/whisper). - - :param audio_files: A list of paths or binary streams to transcribe. - :returns: A list of transcriptions. 
- """ - return_segments = kwargs.pop("return_segments", False) - transcriptions = [] - for audio_file in audio_files: - if isinstance(audio_file, (str, Path)): - audio_file = open(audio_file, "rb") - - # mypy compains that _model is not guaranteed to be not None. It is: check self.warm_up() - transcription = self._model.transcribe(audio_file.name, **kwargs) # type: ignore - if not return_segments: - transcription.pop("segments", None) - transcriptions.append(transcription) - - return transcriptions diff --git a/haystack/preview/components/audio/whisper_remote.py b/haystack/preview/components/audio/whisper_remote.py deleted file mode 100644 index 848a7c170f..0000000000 --- a/haystack/preview/components/audio/whisper_remote.py +++ /dev/null @@ -1,141 +0,0 @@ -import io -import logging -import os -from typing import Any, Dict, List, Optional, Union -from pathlib import Path - -import openai - -from haystack.preview import Document, component, default_from_dict, default_to_dict -from haystack.preview.dataclasses import ByteStream - -logger = logging.getLogger(__name__) - - -API_BASE_URL = "https://api.openai.com/v1" - - -@component -class RemoteWhisperTranscriber: - """ - Transcribes audio files using OpenAI's Whisper using OpenAI API. Requires an API key. See the - [OpenAI blog post](https://beta.openai.com/docs/api-reference/whisper for more details. - You can get one by signing up for an [OpenAI account](https://beta.openai.com/). - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "whisper-1", - organization: Optional[str] = None, - api_base_url: str = API_BASE_URL, - **kwargs, - ): - """ - Transcribes a list of audio files into a list of Documents. - - :param api_key: OpenAI API key. - :param model_name: Name of the model to use. It now accepts only `whisper-1`. - :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI - [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). - :param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`. - :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI - endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more details. - Some of the supported parameters: - - `language`: The language of the input audio. - Supplying the input language in ISO-639-1 format - will improve accuracy and latency. - - `prompt`: An optional text to guide the model's - style or continue a previous audio segment. - The prompt should match the audio language. - - `response_format`: The format of the transcript - output, in one of these options: json, text, srt, - verbose_json, or vtt. Defaults to "json". Currently only "json" is supported. - - `temperature`: The sampling temperature, between 0 - and 1. Higher values like 0.8 will make the output more - random, while lower values like 0.2 will make it more - focused and deterministic. If set to 0, the model will - use log probability to automatically increase the - temperature until certain thresholds are hit. 
- """ - - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "RemoteWhisperTranscriber expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - openai.api_key = api_key - - self.organization = organization - self.model_name = model_name - self.api_base_url = api_base_url - - # Only response_format = "json" is supported - whisper_params = kwargs - if whisper_params.get("response_format") != "json": - logger.warning( - "RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten." - ) - whisper_params["response_format"] = "json" - self.whisper_params = whisper_params - - if organization is not None: - openai.organization = organization - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - This method overrides the default serializer in order to - avoid leaking the `api_key` value passed to the constructor. - """ - return default_to_dict( - self, - model_name=self.model_name, - organization=self.organization, - api_base_url=self.api_base_url, - **self.whisper_params, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber": - """ - Deserialize this component from a dictionary. - """ - return default_from_dict(cls, data) - - @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]]): - """ - Transcribe the audio files into a list of Documents, one for each input file. - - For the supported audio formats, languages, and other parameters, see the - [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper - [github repo](https://github.com/openai/whisper). - - :param audio_files: a list of ByteStream objects to transcribe. - :returns: a list of Documents, one for each file. The content of the document is the transcription text. 
- """ - documents = [] - - for source in sources: - if not isinstance(source, ByteStream): - path = source - source = ByteStream.from_file_path(Path(source)) - source.metadata["file_path"] = path - - file = io.BytesIO(source.data) - file.name = str(source.metadata["file_path"]) if "file_path" in source.metadata else "__fallback__.wav" - - content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params) - doc = Document(content=content["text"], meta=source.metadata) - documents.append(doc) - - return {"documents": documents} diff --git a/haystack/preview/components/builders/__init__.py b/haystack/preview/components/builders/__init__.py deleted file mode 100644 index 6f47bce0d7..0000000000 --- a/haystack/preview/components/builders/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from haystack.preview.components.builders.answer_builder import AnswerBuilder -from haystack.preview.components.builders.prompt_builder import PromptBuilder -from haystack.preview.components.builders.dynamic_prompt_builder import DynamicPromptBuilder - -__all__ = ["AnswerBuilder", "PromptBuilder", "DynamicPromptBuilder"] diff --git a/haystack/preview/components/builders/answer_builder.py b/haystack/preview/components/builders/answer_builder.py deleted file mode 100644 index 0f44a33bc3..0000000000 --- a/haystack/preview/components/builders/answer_builder.py +++ /dev/null @@ -1,142 +0,0 @@ -import logging -import re -from typing import List, Dict, Any, Optional - -from haystack.preview import component, GeneratedAnswer, Document - - -logger = logging.getLogger(__name__) - - -@component -class AnswerBuilder: - """ - A component to parse the output of a Generator to `Answer` objects using regular expressions. - """ - - def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None): - """ - :param pattern: The regular expression pattern to use to extract the answer text from the generator output. - If not specified, the whole string is used as the answer. The regular expression can have at - most one capture group. If a capture group is present, the text matched by the capture group - is used as the answer. If no capture group is present, the whole match is used as the answer. - Examples: - `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\nthis is an answer". - `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". - Default: `None`. - :param reference_pattern: The regular expression pattern to use for parsing the document references. - We assume that references are specified as indices of the input documents and that - indices start at 1. - Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". - If not specified, no parsing is done, and all documents are referenced. - Default: `None`. - """ - if pattern: - AnswerBuilder._check_num_groups_in_regex(pattern) - - self.pattern = pattern - self.reference_pattern = reference_pattern - - @component.output_types(answers=List[GeneratedAnswer]) - def run( - self, - query: str, - replies: List[str], - metadata: Optional[List[Dict[str, Any]]] = None, - documents: Optional[List[Document]] = None, - pattern: Optional[str] = None, - reference_pattern: Optional[str] = None, - ): - """ - Parse the output of a Generator to `Answer` objects using regular expressions. - - :param query: The query used in the prompts for the Generator. A strings. - :param replies: The output of the Generator. A list of strings. 
- :param metadata: The metadata returned by the Generator. An optional list of dictionaries. If not specified, - the generated answer will contain no metadata. - :param documents: The documents used as input to the Generator. A list of `Document` objects. If - `documents` are specified, they are added to the `Answer` objects. - If both `documents` and `reference_pattern` are specified, the documents referenced in the - Generator output are extracted from the input documents and added to the `Answer` objects. - Default: `None`. - :param pattern: The regular expression pattern to use to extract the answer text from the generator output. - If not specified, the whole string is used as the answer. The regular expression can have at - most one capture group. If a capture group is present, the text matched by the capture group - is used as the answer. If no capture group is present, the whole match is used as the answer. - Examples: - `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\nthis is an answer". - `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". - Default: `None`. - :param reference_pattern: The regular expression pattern to use for parsing the document references. - We assume that references are specified as indices of the input documents and that - indices start at 1. - Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". - If not specified, no parsing is done, and all documents are referenced. - Default: `None`. - """ - if not metadata: - metadata = [{}] * len(replies) - elif len(replies) != len(metadata): - raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.") - - if pattern: - AnswerBuilder._check_num_groups_in_regex(pattern) - - pattern = pattern or self.pattern - reference_pattern = reference_pattern or self.reference_pattern - - all_answers = [] - for reply, meta in zip(replies, metadata): - referenced_docs = [] - if documents: - reference_idxs = [] - if reference_pattern: - reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern) - else: - reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)] - - for idx in reference_idxs: - try: - referenced_docs.append(documents[idx]) - except IndexError: - logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1) - - answer_string = AnswerBuilder._extract_answer_string(reply, pattern) - answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta) - all_answers.append(answer) - - return {"answers": all_answers} - - @staticmethod - def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str: - """ - Extract the answer string from the generator output using the specified pattern. - If no pattern is specified, the whole string is used as the answer. - - :param replies: The output of the Generator. A string. - :param pattern: The regular expression pattern to use to extract the answer text from the generator output. 
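
To make the reply-parsing behaviour described above concrete, here is a small sketch of the `AnswerBuilder` used in isolation. The query, replies, documents, and regular expressions are invented for illustration and are not taken from the original module.

```python
from haystack.preview import Document
from haystack.preview.components.builders.answer_builder import AnswerBuilder

# Extract the text after "Answer:" and treat "[n]" as a 1-based reference to the input documents.
builder = AnswerBuilder(pattern=r"Answer: (.*)", reference_pattern=r"\[(\d+)\]")

result = builder.run(
    query="Where is Berlin?",  # illustrative values only
    replies=["Thinking out loud. Answer: Berlin is in Germany [1]"],
    documents=[
        Document(content="Berlin is the capital of Germany."),
        Document(content="Paris is in France."),
    ],
)

answer = result["answers"][0]
print(answer.data)                              # "Berlin is in Germany [1]"
print([d.content for d in answer.documents])    # only the first document is referenced
```
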
- """ - if pattern is None: - return reply - - if match := re.search(pattern, reply): - # No capture group in pattern -> use the whole match as answer - if not match.lastindex: - return match.group(0) - # One capture group in pattern -> use the capture group as answer - return match.group(1) - return "" - - @staticmethod - def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]: - document_idxs = re.findall(reference_pattern, reply) - return [int(idx) - 1 for idx in document_idxs] - - @staticmethod - def _check_num_groups_in_regex(pattern: str): - num_groups = re.compile(pattern).groups - if num_groups > 1: - raise ValueError( - f"Pattern '{pattern}' contains multiple capture groups. " - f"Please specify a pattern with at most one capture group." - ) diff --git a/haystack/preview/components/builders/dynamic_prompt_builder.py b/haystack/preview/components/builders/dynamic_prompt_builder.py deleted file mode 100644 index 36ba15f804..0000000000 --- a/haystack/preview/components/builders/dynamic_prompt_builder.py +++ /dev/null @@ -1,331 +0,0 @@ -import logging -from typing import Dict, Any, Optional, List, Union, Set - -from jinja2 import Template, meta - -from haystack.preview import component -from haystack.preview import default_to_dict -from haystack.preview.dataclasses.chat_message import ChatMessage, ChatRole - -logger = logging.getLogger(__name__) - - -@component -class DynamicPromptBuilder: - """ - DynamicPromptBuilder is designed to construct dynamic prompts by processing either a list of `ChatMessage` - instances or a string template. It integrates with Jinja2 templating for dynamic prompt generation. - - In the case of `ChatMessage` instances, DynamicPromptBuilder assumes the last user message in the list as a - template and renders it with resolved pipeline variables and any additional template variables provided. For a - string template, it applies the template variables directly to render the final prompt. This dual functionality - allows DynamicPromptBuilder to be versatile in handling different types of prompt sources, making it suitable for - both chat-based and non-chat-based prompt generation scenarios. - - You can provide additional template variables directly to the pipeline `run` method. They are then merged with the - variables resolved from the pipeline runtime. This allows for greater flexibility and customization of the - generated prompts based on runtime conditions and user inputs. 
- - The following example demonstrates how to use DynamicPromptBuilder to generate a chat prompt: - - ```python - from haystack.preview.components.builders import DynamicPromptBuilder - from haystack.preview.components.generators.chat import GPTChatGenerator - from haystack.preview.dataclasses import ChatMessage - from haystack.preview import Pipeline - - # no parameter init, we don't use any runtime template variables - prompt_builder = DynamicPromptBuilder() - llm = GPTChatGenerator(api_key="", model_name="gpt-3.5-turbo") - - pipe = Pipeline() - pipe.add_component("prompt_builder", prompt_builder) - pipe.add_component("llm", llm) - pipe.connect("prompt_builder.prompt", "llm.messages") - - location = "Berlin" - messages = [ChatMessage.from_system("Always respond in German even if some input data is in other languages."), - ChatMessage.from_user("Tell me about {{location}}")] - - - pipe.run(data={"prompt_builder": {"template_variables":{"location": location}, "prompt_source": messages}}) - - >> {'llm': {'replies': [ChatMessage(content='Berlin ist die Hauptstadt Deutschlands und die größte Stadt des Landes. - >> Es ist eine lebhafte Metropole, die für ihre Geschichte, Kultur und einzigartigen Sehenswürdigkeiten bekannt ist. - >> Berlin bietet eine vielfältige Kulturszene, beeindruckende architektonische Meisterwerke wie den Berliner Dom - >> und das Brandenburger Tor, sowie weltberühmte Museen wie das Pergamonmuseum. Die Stadt hat auch eine pulsierende - >> Clubszene und ist für ihr aufregendes Nachtleben berühmt. Berlin ist ein Schmelztiegel verschiedener Kulturen und - >> zieht jedes Jahr Millionen von Touristen an.', role=, name=None, - >> metadata={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 32, - >> 'completion_tokens': 153, 'total_tokens': 185}})]}} - ``` - - The following example demonstrates how to use DynamicPromptBuilder to generate a chat prompt with resolution - of pipeline runtime variables (such as documents): - - ```python - from haystack.preview.components.builders import DynamicPromptBuilder - from haystack.preview.components.generators.chat import GPTChatGenerator - from haystack.preview.dataclasses import ChatMessage, Document - from haystack.preview import Pipeline, component - from typing import List - - # we'll use documents runtime variable in our template, so we need to specify it in the init - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"]) - llm = GPTChatGenerator(api_key="", model_name="gpt-3.5-turbo") - - - @component - class DocumentProducer: - - @component.output_types(documents=List[Document]) - def run(self, doc_input: str): - return {"documents": [Document(content=doc_input)]} - - - - pipe = Pipeline() - pipe.add_component("doc_producer", DocumentProducer()) - pipe.add_component("prompt_builder", prompt_builder) - pipe.add_component("llm", llm) - - # note here how prompt_builder.documents is received from doc_producer.documents - pipe.connect("doc_producer.documents", "prompt_builder.documents") - pipe.connect("prompt_builder.prompt", "llm.messages") - - messages = [ChatMessage.from_system("Be helpful assistant, but brief!"), - ChatMessage.from_user("Here is the document: {{documents[0].content}} Now, answer the - following: {{query}}")] - - - pipe.run(data={"doc_producer": {"doc_input": "Hello world, I'm Haystack!"}, - "prompt_builder": {"prompt_source": messages, - "template_variables":{"query": "who's making a greeting?"}}}) - - >> {'llm': {'replies': 
[ChatMessage(content='Haystack', role=, name=None, - >> metadata={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': - >> {'prompt_tokens': 51, 'completion_tokens': 2, 'total_tokens': 53}})]}} - ``` - - Similarly to chat prompt generation, you can use DynamicPromptBuilder to generate non-chat-based prompts. - The following example demonstrates how to use DynamicPromptBuilder to generate a non-chat prompt: - - ```python - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=False) - llm = GPTGenerator(api_key="", model_name="gpt-3.5-turbo") - - - @component - class DocumentProducer: - - @component.output_types(documents=List[Document]) - def run(self, doc_input: str): - return {"documents": [Document(content=doc_input)]} - - - pipe = Pipeline() - pipe.add_component("doc_producer", DocumentProducer()) - pipe.add_component("prompt_builder", prompt_builder) - pipe.add_component("llm", llm) - pipe.connect("doc_producer.documents", "prompt_builder.documents") - pipe.connect("prompt_builder.prompt", "llm.prompt") - - template = "Here is the document: {{documents[0].content}} \n Answer: {{query}}" - pipe.run(data={"doc_producer": {"doc_input": "Hello world, I live in Berlin"}, - "prompt_builder": {"prompt_source": template, - "template_variables":{"query": "Where does the speaker live?"}}}) - - >> {'llm': {'replies': ['The speaker lives in Berlin.'], - >> 'metadata': [{'model': 'gpt-3.5-turbo-0613', - >> 'index': 0, - >> 'finish_reason': 'stop', - >> 'usage': {'prompt_tokens': 28, - >> 'completion_tokens': 6, - >> 'total_tokens': 34}}]}} - - """ - - def __init__(self, runtime_variables: Optional[List[str]] = None, chat_mode: Optional[bool] = True): - """ - Initializes DynamicPromptBuilder with the provided variable names. These variable names are used to resolve - variables and their values during pipeline runtime execution. For example, if `runtime_variables` contains - `documents` your instance of DynamicPromptBuilder will expect an input called `documents`. - The values associated with variables from the pipeline runtime are then injected into template placeholders - of either a ChatMessage or a string template that is provided to the `run` method. - See `run` method for more details. - - :param runtime_variables: A list of template variable names you can use in chat prompt construction. - :type runtime_variables: Optional[List[str]] - :param chat_mode: A boolean flag to indicate if the chat prompt is being built for a chat-based prompt - templating. Defaults to True. - :type chat_mode: Optional[bool] - """ - runtime_variables = runtime_variables or [] - - if not runtime_variables: - logger.warning( - "template_variables were not provided, DynamicPromptBuilder will not resolve any pipeline variables." 
- ) - # setup inputs - if chat_mode: - run_input_slots = {"prompt_source": List[ChatMessage], "template_variables": Optional[Dict[str, Any]]} - else: - run_input_slots = {"prompt_source": str, "template_variables": Optional[Dict[str, Any]]} - - kwargs_input_slots = {var: Optional[Any] for var in runtime_variables} - component.set_input_types(self, **run_input_slots, **kwargs_input_slots) - - # setup outputs - if chat_mode: - component.set_output_types(self, prompt=List[ChatMessage]) - else: - component.set_output_types(self, prompt=str) - - self.runtime_variables = runtime_variables - self.chat_mode = chat_mode - - def to_dict(self) -> Dict[str, Any]: - """ - Converts the `DynamicPromptBuilder` instance to a dictionary format, primarily for serialization purposes. - - :return: A dictionary representation of the `DynamicPromptBuilder` instance, including its template variables. - :rtype: Dict[str, Any] - """ - return default_to_dict(self, runtime_variables=self.runtime_variables, chat_mode=self.chat_mode) - - def run( - self, - prompt_source: Union[List[ChatMessage], str], - template_variables: Optional[Dict[str, Any]] = None, - **kwargs, - ): - """ - Executes the dynamic prompt building process. Depending on the provided type of `prompt_source`, this method - either processes a list of `ChatMessage` instances or a string template. In the case of `ChatMessage` instances, - the last user message is treated as a template and rendered with the resolved pipeline variables and any - additional template variables provided. For a string template, it directly applies the template variables to - render the final prompt. You can provide additional template variables directly to this method, that are then merged - with the variables resolved from the pipeline runtime. - - :param prompt_source: A list of `ChatMessage` instances or a string template. The list scenario assumes the last - user message as the template for the chat prompt, while the string scenario is used for non-chat-based prompts. - :type prompt_source: Union[List[ChatMessage], str] - - :param template_variables: An optional dictionary of template variables. Template variables provided at - initialization are required to resolve pipeline variables, and these are additional variables users can - provide directly to this method. - :type template_variables: Optional[Dict[str, Any]] - - :param kwargs: Additional keyword arguments, typically resolved from a pipeline, which are merged with the - provided template variables. - - :return: A dictionary containing the key "prompt", which holds either the updated list of `ChatMessage` - instances or the rendered string template, forming the complete dynamic prompt. - :rtype: Dict[str, Union[List[ChatMessage], str]] - """ - kwargs = kwargs or {} - template_variables = template_variables or {} - template_variables_combined = {**kwargs, **template_variables} - if not template_variables_combined: - raise ValueError( - "The DynamicPromptBuilder run method requires template variables, but none were provided. " - "Please provide an appropriate template variable to enable prompt generation." 
- ) - # some of these checks are superfluous because pipeline will check them as well but let's - # handle them anyway for better error messages and robustness - result: Union[List[ChatMessage], str] - if isinstance(prompt_source, str): - result = self._process_simple_template(prompt_source, template_variables_combined) - elif isinstance(prompt_source, list): - result = self._process_chat_messages(prompt_source, template_variables_combined) - else: - raise ValueError( - f"{self.__class__.__name__} was not provided with a list of ChatMessage(s) or a string template." - "Please check the parameters passed to its run method." - ) - return {"prompt": result} - - def _process_simple_template(self, prompt_source: str, template_variables: Dict[str, Any]) -> str: - """ - Renders the template from the provided string source with the provided template variables. - - :param prompt_source: A Jinja2 template as a string. - :type prompt_source: str - :param template_variables: A dictionary of template variables. - :type template_variables: Dict[str, Any] - :return: A string containing the rendered template. - :rtype: str - """ - template = self._validate_template(prompt_source, set(template_variables.keys())) - return template.render(template_variables) - - def _process_chat_messages(self, prompt_source: List[ChatMessage], template_variables: Dict[str, Any]): - """ - Processes a list of :class:`ChatMessage` instances to generate a chat prompt. - - It takes the last user message in the list, treats it as a template, and renders it with the provided - template variables. The resulting message replaces the last user message in the list, forming a complete, - templated chat prompt. - - :param prompt_source: A list of `ChatMessage` instances to be processed. The last message is expected - to be from a user and is treated as a template. - :type prompt_source: List[ChatMessage] - - :param template_variables: A dictionary of template variables used for rendering the last user message. - :type template_variables: Dict[str, Any] - - :return: A list of `ChatMessage` instances, where the last user message has been replaced with its - templated version. - :rtype: List[ChatMessage] - - :raises ValueError: If `chat_messages` is empty or contains elements that are not instances of - `ChatMessage`. - :raises ValueError: If the last message in `chat_messages` is not from a user. - """ - if not prompt_source: - raise ValueError( - f"The {self.__class__.__name__} requires a non-empty list of ChatMessage instances. " - f"Please provide a valid list of ChatMessage instances to render the prompt." - ) - if not all(isinstance(message, ChatMessage) for message in prompt_source): - raise ValueError( - f"The {self.__class__.__name__} expects a list containing only ChatMessage instances. " - f"The provided list contains other types. Please ensure that all elements in the list " - f"are ChatMessage instances." - ) - - last_message: ChatMessage = prompt_source[-1] - if last_message.is_from(ChatRole.USER): - template = self._validate_template(last_message.content, set(template_variables.keys())) - templated_user_message = ChatMessage.from_user(template.render(template_variables)) - return prompt_source[:-1] + [templated_user_message] - else: - logger.warning( - "DynamicPromptBuilder was not provided with a user message as the last message in " - "chat conversation, no templating will be applied." 
- ) - return prompt_source - - def _validate_template(self, template_text: str, provided_variables: Set[str]): - """ - Checks if all the required template variables are provided to the pipeline `run` method. - If all the required template variables are provided, returns a Jinja2 template object. - Otherwise, raises a ValueError. - - :param template_text: A Jinja2 template as a string. - :param provided_variables: A set of provided template variables. - :type provided_variables: Set[str] - :return: A Jinja2 template object if all the required template variables are provided. - :raises ValueError: If all the required template variables are not provided. - """ - template = Template(template_text) - ast = template.environment.parse(template_text) - required_template_variables = meta.find_undeclared_variables(ast) - filled_template_vars = required_template_variables.intersection(provided_variables) - if len(filled_template_vars) != len(required_template_variables): - raise ValueError( - f"The {self.__class__.__name__} requires specific template variables that are missing. " - f"Required variables: {required_template_variables}. Only the following variables were " - f"provided: {provided_variables}. Please provide all the required template variables." - ) - return template diff --git a/haystack/preview/components/builders/prompt_builder.py b/haystack/preview/components/builders/prompt_builder.py deleted file mode 100644 index b6e0da6fb4..0000000000 --- a/haystack/preview/components/builders/prompt_builder.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Dict, Any - -from jinja2 import Template, meta - -from haystack.preview import component -from haystack.preview import default_to_dict - - -@component -class PromptBuilder: - """ - PromptBuilder is a component that renders a prompt from a template string using Jinja2 engine. - The template variables found in the template string are used as input types for the component and are all required. - - Usage: - ```python - template = "Translate the following context to {{ target_language }}. Context: {{ snippet }}; Translation:" - builder = PromptBuilder(template=template) - builder.run(target_language="spanish", snippet="I can't speak spanish.") - ``` - """ - - def __init__(self, template: str): - """ - Initialize the component with a template string. - - :param template: Jinja2 template string, e.g. 
"Summarize this document: {documents}\nSummary:" - :type template: str - """ - self._template_string = template - self.template = Template(template) - ast = self.template.environment.parse(template) - template_variables = meta.find_undeclared_variables(ast) - component.set_input_types(self, **{var: Any for var in template_variables}) - - def to_dict(self) -> Dict[str, Any]: - return default_to_dict(self, template=self._template_string) - - @component.output_types(prompt=str) - def run(self, **kwargs): - return {"prompt": self.template.render(kwargs)} diff --git a/haystack/preview/components/caching/__init__.py b/haystack/preview/components/caching/__init__.py deleted file mode 100644 index d2e8a69c1f..0000000000 --- a/haystack/preview/components/caching/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.caching.url_cache_checker import UrlCacheChecker - -__all__ = ["UrlCacheChecker"] diff --git a/haystack/preview/components/caching/url_cache_checker.py b/haystack/preview/components/caching/url_cache_checker.py deleted file mode 100644 index c3d87bcfcc..0000000000 --- a/haystack/preview/components/caching/url_cache_checker.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import List, Dict, Any - -from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError -from haystack.preview.document_stores import DocumentStore, document_store - - -@component -class UrlCacheChecker: - """ - A component checks for the presence of a document from a specific URL in the store. UrlCacheChecker can thus - implement caching functionality within web retrieval pipelines that use a Document Store. - """ - - def __init__(self, document_store: DocumentStore, url_field: str = "url"): - """ - Create a UrlCacheChecker component. - """ - self.document_store = document_store - self.url_field = url_field - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict(self, document_store=self.document_store.to_dict(), url_field=self.url_field) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "UrlCacheChecker": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - if init_params["document_store"]["type"] not in document_store.registry: - raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.") - docstore_class = document_store.registry[init_params["document_store"]["type"]] - docstore = docstore_class.from_dict(init_params["document_store"]) - - data["init_parameters"]["document_store"] = docstore - return default_from_dict(cls, data) - - @component.output_types(hits=List[Document], misses=List[str]) - def run(self, urls: List[str]): - """ - Checks if any document coming from the given URL is already present in the store. If matching documents are - found, they are returned. If not, the URL is returned as a miss. - - :param urls: All the URLs the documents may be coming from to hit this cache. 
- """ - found_documents = [] - missing_urls = [] - - for url in urls: - filters = {self.url_field: url} - found = self.document_store.filter_documents(filters=filters) - if found: - found_documents.extend(found) - else: - missing_urls.append(url) - return {"hits": found_documents, "misses": missing_urls} diff --git a/haystack/preview/components/classifiers/__init__.py b/haystack/preview/components/classifiers/__init__.py deleted file mode 100644 index 6a4cfaee8d..0000000000 --- a/haystack/preview/components/classifiers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.classifiers.document_language_classifier import DocumentLanguageClassifier - -__all__ = ["DocumentLanguageClassifier"] diff --git a/haystack/preview/components/classifiers/document_language_classifier.py b/haystack/preview/components/classifiers/document_language_classifier.py deleted file mode 100644 index 5a04b4675b..0000000000 --- a/haystack/preview/components/classifiers/document_language_classifier.py +++ /dev/null @@ -1,82 +0,0 @@ -import logging -from typing import List, Dict, Optional - -from haystack.preview import component, Document -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - -with LazyImport("Run 'pip install langdetect'") as langdetect_import: - import langdetect - - -@component -class DocumentLanguageClassifier: - """ - Classify the language of documents and add the detected language to their metadata. - A MetadataRouter can then route them onto different output connections depending on their language. - This is useful to route documents to different models in a pipeline depending on their language. - The set of supported languages can be specified. - For routing plain text using the same logic, use the related TextLanguageRouter component instead. - - Example usage within an indexing pipeline, storing in a Document Store - only documents written in English: - - ```python - document_store = InMemoryDocumentStore() - p = Pipeline() - p.add_component(instance=TextFileToDocument(), name="text_file_converter") - p.add_component(instance=DocumentLanguageClassifier(), name="language_classifier") - p.add_component(instance=MetadataRouter(rules={"en": {"language": {"$eq": "en"}}}), name="router") - p.add_component(instance=DocumentWriter(document_store=document_store), name="writer") - p.connect("text_file_converter.documents", "language_classifier.documents") - p.connect("language_classifier.documents", "router.documents") - p.connect("router.en", "writer.documents") - ``` - """ - - def __init__(self, languages: Optional[List[str]] = None): - """ - :param languages: A list of languages in ISO code, each corresponding to a different output connection - (see [langdetect` documentation](https://github.com/Mimino666/langdetect#languages)). - By default, only ["en"] is supported and Documents of any other language are routed to "unmatched". - """ - langdetect_import.check() - if not languages: - languages = ["en"] - self.languages = languages - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]): - """ - Run the DocumentLanguageClassifier. This method classifies the documents' language and adds it to their metadata. - If a Document's text does not match any of the languages specified at initialization, the metadata value "unmatched" will be stored. - - :param documents: A list of documents to classify their language. 
- :return: List of Documents with an added metadata field called language. - """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): - raise TypeError( - "DocumentLanguageClassifier expects a list of Document as input. " - "In case you want to classify a text, please use the TextLanguageClassifier." - ) - - output: Dict[str, List[Document]] = {language: [] for language in self.languages} - output["unmatched"] = [] - - for document in documents: - detected_language = self.detect_language(document) - if detected_language in self.languages: - document.meta["language"] = detected_language - else: - document.meta["language"] = "unmatched" - - return {"documents": documents} - - def detect_language(self, document: Document) -> Optional[str]: - try: - language = langdetect.detect(document.content) - except langdetect.LangDetectException: - logger.warning("Langdetect cannot detect the language of Document with id: %s", document.id) - language = None - return language diff --git a/haystack/preview/components/converters/__init__.py b/haystack/preview/components/converters/__init__.py deleted file mode 100644 index 62367b2222..0000000000 --- a/haystack/preview/components/converters/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from haystack.preview.components.converters.txt import TextFileToDocument -from haystack.preview.components.converters.tika import TikaDocumentConverter -from haystack.preview.components.converters.azure import AzureOCRDocumentConverter -from haystack.preview.components.converters.pypdf import PyPDFToDocument -from haystack.preview.components.converters.html import HTMLToDocument -from haystack.preview.components.converters.markdown import MarkdownToDocument - -__all__ = [ - "TextFileToDocument", - "TikaDocumentConverter", - "AzureOCRDocumentConverter", - "PyPDFToDocument", - "HTMLToDocument", - "MarkdownToDocument", -] diff --git a/haystack/preview/components/converters/azure.py b/haystack/preview/components/converters/azure.py deleted file mode 100644 index 304078d7d2..0000000000 --- a/haystack/preview/components/converters/azure.py +++ /dev/null @@ -1,105 +0,0 @@ -from pathlib import Path -from typing import List, Union, Dict, Any, Optional -import os - -from haystack.preview.lazy_imports import LazyImport -from haystack.preview import component, Document, default_to_dict - - -with LazyImport(message="Run 'pip install \"azure-ai-formrecognizer>=3.2.0b2\"'") as azure_import: - from azure.ai.formrecognizer import DocumentAnalysisClient, AnalyzeResult - from azure.core.credentials import AzureKeyCredential - - -@component -class AzureOCRDocumentConverter: - """ - A component for converting files to Documents using Azure's Document Intelligence service. - Supported file formats are: PDF, JPEG, PNG, BMP, TIFF, DOCX, XLSX, PPTX, and HTML. - - In order to be able to use this component, you need an active Azure account - and a Document Intelligence or Cognitive Services resource. Please follow the steps described in the - [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/quickstarts/get-started-sdks-rest-api) - to set up your resource. - """ - - def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"): - """ - Create an AzureOCRDocumentConverter component. - - :param endpoint: The endpoint of your Azure resource. - :param api_key: The key of your Azure resource. 
It can be - explicitly provided or automatically read from the - environment variable AZURE_AI_API_KEY (recommended). - :param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature) - for a list of available models. Default: `"prebuilt-read"`. - """ - azure_import.check() - - if api_key is None: - try: - api_key = os.environ["AZURE_AI_API_KEY"] - except KeyError as e: - raise ValueError( - "AzureOCRDocumentConverter expects an Azure Credential key. " - "Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - - self.api_key = api_key - self.document_analysis_client = DocumentAnalysisClient( - endpoint=endpoint, credential=AzureKeyCredential(api_key) - ) - self.endpoint = endpoint - self.model_id = model_id - - @component.output_types(documents=List[Document], azure=List[Dict]) - def run(self, paths: List[Union[str, Path]]): - """ - Convert files to Documents using Azure's Document Intelligence service. - - This component creates two outputs: `documents` and `raw_azure_response`. The `documents` output contains - a list of Documents that were created from the files. The `raw_azure_response` output contains a list of - the raw responses from Azure's Document Intelligence service. - - :param paths: Paths to the files to convert. - """ - documents = [] - azure_output = [] - for path in paths: - path = Path(path) - with open(path, "rb") as file: - poller = self.document_analysis_client.begin_analyze_document(model_id=self.model_id, document=file) - result = poller.result() - azure_output.append(result.to_dict()) - - file_suffix = path.suffix - document = AzureOCRDocumentConverter._convert_azure_result_to_document(result, file_suffix) - documents.append(document) - - return {"documents": documents, "raw_azure_response": azure_output} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id) - - @staticmethod - def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: str) -> Document: - """ - Convert the result of Azure OCR to a Haystack text Document. - """ - if file_suffix == ".pdf": - text = "" - for page in result.pages: - lines = page.lines if page.lines else [] - for line in lines: - text += f"{line.content}\n" - - text += "\f" - else: - text = result.content - - document = Document(content=text) - - return document diff --git a/haystack/preview/components/converters/html.py b/haystack/preview/components/converters/html.py deleted file mode 100644 index 8b68119b38..0000000000 --- a/haystack/preview/components/converters/html.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from haystack.preview import Document, component -from haystack.preview.dataclasses import ByteStream -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - -with LazyImport("Run 'pip install boilerpy3'") as boilerpy3_import: - from boilerpy3 import extractors - - -@component -class HTMLToDocument: - """ - Converts an HTML file to a Document. 
- - Usage example: - ```python - from haystack.preview.components.converters.html import HTMLToDocument - - converter = HTMLToDocument() - results = converter.run(sources=["sample.html"]) - documents = results["documents"] - print(documents[0].content) - # 'This is a text from the HTML file.' - ``` - - """ - - def __init__(self): - """ - Initializes the HTMLToDocument component. - """ - boilerpy3_import.check() - - @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None): - """ - Converts a list of HTML files to Documents. - - :param sources: List of HTML file paths or ByteStream objects. - :param meta: Optional list of metadata to attach to the Documents. - The length of the list must match the number of sources. Defaults to `None`. - :return: List of converted Documents. - """ - - documents = [] - - # Create metadata placeholders if not provided - if meta: - if len(sources) != len(meta): - raise ValueError("The length of the metadata list must match the number of sources.") - else: - meta = [{}] * len(sources) - - extractor = extractors.ArticleExtractor(raise_on_failure=False) - - for source, metadata in zip(sources, meta): - try: - file_content, extracted_meta = self._extract_content(source) - except Exception as e: - logger.warning("Could not read %s. Skipping it. Error: %s", source, e) - continue - try: - text = extractor.get_content(file_content) - except Exception as conversion_e: # Consider specifying the expected exception type(s) here - logger.warning("Failed to extract text from %s. Skipping it. Error: %s", source, conversion_e) - continue - - # Merge metadata received from ByteStream with supplied metadata - if extracted_meta: - # Supplied metadata overwrites metadata from ByteStream for overlapping keys. - metadata = {**extracted_meta, **metadata} - document = Document(content=text, meta=metadata) - documents.append(document) - - return {"documents": documents} - - def _extract_content(self, source: Union[str, Path, ByteStream]) -> tuple: - """ - Extracts content from the given data source - :param source: The data source to extract content from. - :return: The extracted content and metadata. - """ - if isinstance(source, (str, Path)): - with open(source) as text_file: - return (text_file.read(), None) - if isinstance(source, ByteStream): - return (source.data.decode("utf-8"), source.metadata) - - raise ValueError(f"Unsupported source type: {type(source)}") diff --git a/haystack/preview/components/converters/markdown.py b/haystack/preview/components/converters/markdown.py deleted file mode 100644 index 145b78a910..0000000000 --- a/haystack/preview/components/converters/markdown.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from tqdm import tqdm - -from haystack.preview import Document, component -from haystack.preview.dataclasses import ByteStream -from haystack.preview.lazy_imports import LazyImport - -with LazyImport("Run 'pip install markdown-it-py mdit_plain'") as markdown_conversion_imports: - from markdown_it import MarkdownIt - from mdit_plain.renderer import RendererPlain - - -logger = logging.getLogger(__name__) - - -@component -class MarkdownToDocument: - """ - Converts a Markdown file into a text Document. 
- - Usage example: - ```python - from haystack.preview.components.converters.markdown import MarkdownToDocument - - converter = MarkdownToDocument() - results = converter.run(sources=["sample.md"]) - documents = results["documents"] - print(documents[0].content) - # 'This is a text from the markdown file.' - ``` - """ - - def __init__(self, table_to_single_line: bool = False, progress_bar: bool = True): - """ - :param table_to_single_line: Convert contents of the table into a single line. Defaults to False. - :param progress_bar: Show a progress bar for the conversion. Defaults to True. - """ - markdown_conversion_imports.check() - - self.table_to_single_line = table_to_single_line - self.progress_bar = progress_bar - - @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None): - """ - Reads text from a markdown file and executes optional preprocessing steps. - - :param sources: A list of markdown data sources (file paths or binary objects) - :param meta: Optional list of metadata to attach to the Documents. - The length of the list must match the number of paths. Defaults to `None`. - """ - parser = MarkdownIt(renderer_cls=RendererPlain) - if self.table_to_single_line: - parser.enable("table") - - documents = [] - if meta is None: - meta = [{}] * len(sources) - - for source, metadata in tqdm( - zip(sources, meta), - total=len(sources), - desc="Converting markdown files to Documents", - disable=not self.progress_bar, - ): - try: - file_content = self._extract_content(source) - except Exception as e: - logger.warning("Could not read %s. Skipping it. Error: %s", source, e) - continue - try: - text = parser.render(file_content) - except Exception as conversion_e: # Consider specifying the expected exception type(s) here - logger.warning("Failed to extract text from %s. Skipping it. Error: %s", source, conversion_e) - continue - - document = Document(content=text, meta=metadata) - documents.append(document) - - return {"documents": documents} - - def _extract_content(self, source: Union[str, Path, ByteStream]) -> str: - """ - Extracts content from the given data source. - :param source: The data source to extract content from. - :return: The extracted content. - """ - if isinstance(source, (str, Path)): - with open(source) as text_file: - return text_file.read() - if isinstance(source, ByteStream): - return source.data.decode("utf-8") - - raise ValueError(f"Unsupported source type: {type(source)}") diff --git a/haystack/preview/components/converters/pypdf.py b/haystack/preview/components/converters/pypdf.py deleted file mode 100644 index ede7da3816..0000000000 --- a/haystack/preview/components/converters/pypdf.py +++ /dev/null @@ -1,105 +0,0 @@ -import io -import logging -from typing import List, Union, Protocol, Dict -from pathlib import Path - -from haystack.preview.dataclasses import ByteStream -from haystack.preview.lazy_imports import LazyImport -from haystack.preview import Document, component, default_to_dict - -with LazyImport("Run 'pip install pypdf'") as pypdf_import: - from pypdf import PdfReader - - -logger = logging.getLogger(__name__) - - -class PyPDFConverter(Protocol): - """ - A protocol that defines a converter which takes a PdfReader object and converts it into a Document object. - """ - - def convert(self, reader: "PdfReader") -> Document: - ... 
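
Since the protocol above is the extension point, here is a sketch of plugging in a custom converter through the registry defined just below. The converter name, the first-page-only behaviour, and the `pages_used` meta key are invented for illustration.

```python
from pypdf import PdfReader

from haystack.preview import Document
from haystack.preview.components.converters.pypdf import CONVERTERS_REGISTRY, PyPDFToDocument


class FirstPageOnlyConverter:
    """Illustrative converter: keeps only the text of the first page."""

    def convert(self, reader: "PdfReader") -> Document:
        first_page_text = reader.pages[0].extract_text() if reader.pages else None
        return Document(content=first_page_text or "", meta={"pages_used": 1})  # invented meta key


# Register the converter under a custom name, then select it by name when building the component.
CONVERTERS_REGISTRY["first_page_only"] = FirstPageOnlyConverter()
converter = PyPDFToDocument(converter_name="first_page_only")
result = converter.run(sources=["report.pdf"])  # placeholder path
```
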
- - -class DefaultConverter: - """ - The default converter class that extracts text from a PdfReader object's pages and returns a Document. - """ - - def convert(self, reader: "PdfReader") -> Document: - """Extract text from the PDF and return a Document object with the text content.""" - text = "".join(page.extract_text() for page in reader.pages if page.extract_text()) - return Document(content=text) - - -# This registry is used to store converters names and instances. -# It can be used to register custom converters. -CONVERTERS_REGISTRY: Dict[str, PyPDFConverter] = {"default": DefaultConverter()} - - -@component -class PyPDFToDocument: - """ - Converts PDF files to Document objects. - It uses a converter that follows the PyPDFConverter protocol to perform the conversion. - A default text extraction converter is used if no custom converter is provided. - """ - - def __init__(self, converter_name: str = "default"): - """ - Initializes the PyPDFToDocument component with an optional custom converter. - :param converter_name: A converter name that is registered in the CONVERTERS_REGISTRY. - Defaults to 'default'. - """ - pypdf_import.check() - - try: - converter = CONVERTERS_REGISTRY[converter_name] - except KeyError: - msg = ( - f"Invalid converter_name: {converter_name}.\n Available converters: {list(CONVERTERS_REGISTRY.keys())}" - ) - raise ValueError(msg) from KeyError - self.converter_name = converter_name - self._converter: PyPDFConverter = converter - - def to_dict(self): - # do not serialize the _converter instance - return default_to_dict(self, converter_name=self.converter_name) - - @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]]): - """ - Converts a list of PDF sources into Document objects using the configured converter. - - :param sources: A list of PDF data sources, which can be file paths or ByteStream objects. - :return: A dictionary containing a list of Document objects under the 'documents' key. - """ - documents = [] - for source in sources: - try: - pdf_reader = self._get_pdf_reader(source) - document = self._converter.convert(pdf_reader) - except Exception as e: - logger.warning("Could not read %s and convert it to Document, skipping. %s", source, e) - continue - documents.append(document) - - return {"documents": documents} - - def _get_pdf_reader(self, source: Union[str, Path, ByteStream]) -> "PdfReader": - """ - Creates a PdfReader object from a given source, which can be a file path or a ByteStream object. - - :param source: The source of the PDF data. - :return: A PdfReader instance initialized with the PDF data from the source. - :raises ValueError: If the source type is not supported. 
- """ - if isinstance(source, (str, Path)): - return PdfReader(str(source)) - elif isinstance(source, ByteStream): - return PdfReader(io.BytesIO(source.data)) - else: - raise ValueError(f"Unsupported source type: {type(source)}") diff --git a/haystack/preview/components/converters/tika.py b/haystack/preview/components/converters/tika.py deleted file mode 100644 index 89f7c217eb..0000000000 --- a/haystack/preview/components/converters/tika.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Union - -from haystack.preview.lazy_imports import LazyImport -from haystack.preview import component, Document - - -with LazyImport("Run 'pip install tika'") as tika_import: - from tika import parser as tika_parser - -logger = logging.getLogger(__name__) - - -@component -class TikaDocumentConverter: - """ - A component for converting files of different types (pdf, docx, html, etc.) to Documents. - This component uses [Apache Tika](https://tika.apache.org/) for parsing the files and, therefore, - requires a running Tika server. - """ - - def __init__(self, tika_url: str = "http://localhost:9998/tika"): - """ - Create a TikaDocumentConverter component. - - :param tika_url: URL of the Tika server. Default: `"http://localhost:9998/tika"` - """ - tika_import.check() - self.tika_url = tika_url - - @component.output_types(documents=List[Document]) - def run(self, paths: List[Union[str, Path]]): - """ - Convert files to Documents. - - :param paths: A list of paths to the files to convert. - """ - - documents = [] - for path in paths: - path = Path(path) - try: - parsed_file = tika_parser.from_file(path.as_posix(), self.tika_url) - extracted_text = parsed_file["content"] - if not extracted_text: - logger.warning("Skipping file at '%s' as Tika was not able to extract any content.", str(path)) - continue - document = Document(content=extracted_text) - documents.append(document) - except Exception as e: - logger.error("Could not convert file at '%s' to Document. Error: %s", str(path), e) - - return {"documents": documents} diff --git a/haystack/preview/components/converters/txt.py b/haystack/preview/components/converters/txt.py deleted file mode 100644 index 2e63f72861..0000000000 --- a/haystack/preview/components/converters/txt.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Union - -from haystack.preview import Document, component -from haystack.preview.dataclasses import ByteStream - - -logger = logging.getLogger(__name__) - - -@component -class TextFileToDocument: - """ - A component for converting a text file to a Document. - """ - - def __init__(self, encoding: str = "utf-8"): - """ - Create a TextFileToDocument component. - - :param encoding: The default encoding of the text files. Default: `"utf-8"`. - Note that if the encoding is specified in the metadata of a ByteStream, - it will override this default. - """ - self.encoding = encoding - - @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]]): - """ - Convert text files to Documents. - - :param streams: A list of paths to text files or ByteStream objects. - Note that if an encoding is specified in the metadata of a ByteStream, - it will override the component's default. - :return: A dictionary containing the converted documents. 
- """ - documents = [] - for source in sources: - if isinstance(source, (Path, str)): - try: - path = source - source = ByteStream.from_file_path(Path(source)) - source.metadata["file_path"] = str(path) - except Exception as e: - logger.warning("Could not convert file %s. Skipping it. Error message: %s", source, e) - continue - try: - encoding = source.metadata.get("encoding", self.encoding) - document = Document(content=source.data.decode(encoding)) - document.meta = source.metadata - documents.append(document) - except Exception as e: - logger.warning("Could not convert file %s. Skipping it. Error message: %s", source, e) - - return {"documents": documents} diff --git a/haystack/preview/components/embedders/__init__.py b/haystack/preview/components/embedders/__init__.py deleted file mode 100644 index a0840d7e0a..0000000000 --- a/haystack/preview/components/embedders/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder -from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( - SentenceTransformersDocumentEmbedder, -) -from haystack.preview.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder -from haystack.preview.components.embedders.openai_text_embedder import OpenAITextEmbedder - -__all__ = [ - "SentenceTransformersTextEmbedder", - "SentenceTransformersDocumentEmbedder", - "OpenAITextEmbedder", - "OpenAIDocumentEmbedder", -] diff --git a/haystack/preview/components/embedders/backends/__init__.py b/haystack/preview/components/embedders/backends/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/haystack/preview/components/embedders/backends/sentence_transformers_backend.py b/haystack/preview/components/embedders/backends/sentence_transformers_backend.py deleted file mode 100644 index 8883c235a4..0000000000 --- a/haystack/preview/components/embedders/backends/sentence_transformers_backend.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import List, Optional, Union, Dict - -from haystack.preview.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as sentence_transformers_import: - from sentence_transformers import SentenceTransformer - - -class _SentenceTransformersEmbeddingBackendFactory: - """ - Factory class to create instances of Sentence Transformers embedding backends. - """ - - _instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {} - - @staticmethod - def get_embedding_backend( - model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None - ): - embedding_backend_id = f"{model_name_or_path}{device}{use_auth_token}" - - if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances: - return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] - embedding_backend = _SentenceTransformersEmbeddingBackend( - model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token - ) - _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend - return embedding_backend - - -class _SentenceTransformersEmbeddingBackend: - """ - Class to manage Sentence Transformers embeddings. 
- """ - - def __init__( - self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None - ): - sentence_transformers_import.check() - self.model = SentenceTransformer( - model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token - ) - - def embed(self, data: List[str], **kwargs) -> List[List[float]]: - embeddings = self.model.encode(data, **kwargs).tolist() - return embeddings diff --git a/haystack/preview/components/embedders/openai_document_embedder.py b/haystack/preview/components/embedders/openai_document_embedder.py deleted file mode 100644 index 1a2f1e1f18..0000000000 --- a/haystack/preview/components/embedders/openai_document_embedder.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import List, Optional, Dict, Any, Tuple -import os - -import openai -from tqdm import tqdm - - -from haystack.preview import component, Document, default_to_dict - - -@component -class OpenAIDocumentEmbedder: - """ - A component for computing Document embeddings using OpenAI models. - The embedding of each Document is stored in the `embedding` field of the Document. - - Usage example: - ```python - from haystack.preview import Document - from haystack.preview.components.embedders import OpenAIDocumentEmbedder - - doc = Document(text="I love pizza!") - - document_embedder = OpenAIDocumentEmbedder() - - result = document_embedder.run([doc]) - print(result['documents'][0].embedding) - - # [0.017020374536514282, -0.023255806416273117, ...] - ``` - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "text-embedding-ada-002", - organization: Optional[str] = None, - prefix: str = "", - suffix: str = "", - batch_size: int = 32, - progress_bar: bool = True, - metadata_fields_to_embed: Optional[List[str]] = None, - embedding_separator: str = "\n", - ): - """ - Create a OpenAIDocumentEmbedder component. - :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the - environment variable OPENAI_API_KEY (recommended). - :param model_name: The name of the model to use. - :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. - :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI - [documentation](https://platform.openai.com/docs/api-reference/requesting-organization). - :param prefix: A string to add to the beginning of each text. - :param suffix: A string to add to the end of each text. - :param batch_size: Number of Documents to encode at once. - :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments - to keep the logs clean. - :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. - :param embedding_separator: Separator used to concatenate the meta fields to the Document text. - """ - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "OpenAIDocumentEmbedder expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." 
- ) from e - - self.model_name = model_name - self.organization = organization - self.prefix = prefix - self.suffix = suffix - self.batch_size = batch_size - self.progress_bar = progress_bar - self.metadata_fields_to_embed = metadata_fields_to_embed or [] - self.embedding_separator = embedding_separator - - openai.api_key = api_key - if organization is not None: - openai.organization = organization - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name} - - def to_dict(self) -> Dict[str, Any]: - """ - This method overrides the default serializer in order to avoid leaking the `api_key` value passed - to the constructor. - """ - return default_to_dict( - self, - model_name=self.model_name, - organization=self.organization, - prefix=self.prefix, - suffix=self.suffix, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - metadata_fields_to_embed=self.metadata_fields_to_embed, - embedding_separator=self.embedding_separator, - ) - - def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: - """ - Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. - """ - texts_to_embed = [] - for doc in documents: - meta_values_to_embed = [ - str(doc.meta[key]) - for key in self.metadata_fields_to_embed - if key in doc.meta and doc.meta[key] is not None - ] - - text_to_embed = ( - self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix - ) - - # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py) - # replace newlines, which can negatively affect performance. - text_to_embed = text_to_embed.replace("\n", " ") - texts_to_embed.append(text_to_embed) - return texts_to_embed - - def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: - """ - Embed a list of texts in batches. - """ - - all_embeddings = [] - metadata = {} - for i in tqdm( - range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" - ): - batch = texts_to_embed[i : i + batch_size] - response = openai.Embedding.create(model=self.model_name, input=batch) - embeddings = [el["embedding"] for el in response.data] - all_embeddings.extend(embeddings) - - if "model" not in metadata: - metadata["model"] = response.model - if "usage" not in metadata: - metadata["usage"] = dict(response.usage.items()) - else: - metadata["usage"]["prompt_tokens"] += response.usage.prompt_tokens - metadata["usage"]["total_tokens"] += response.usage.total_tokens - - return all_embeddings, metadata - - @component.output_types(documents=List[Document], metadata=Dict[str, Any]) - def run(self, documents: List[Document]): - """ - Embed a list of Documents. - The embedding of each Document is stored in the `embedding` field of the Document. - - :param documents: A list of Documents to embed. - """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): - raise TypeError( - "OpenAIDocumentEmbedder expects a list of Documents as input." - "In case you want to embed a string, please use the OpenAITextEmbedder." 
- ) - - texts_to_embed = self._prepare_texts_to_embed(documents=documents) - - embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) - - for doc, emb in zip(documents, embeddings): - doc.embedding = emb - - return {"documents": documents, "metadata": metadata} diff --git a/haystack/preview/components/embedders/openai_text_embedder.py b/haystack/preview/components/embedders/openai_text_embedder.py deleted file mode 100644 index 7be6065a83..0000000000 --- a/haystack/preview/components/embedders/openai_text_embedder.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import List, Optional, Dict, Any -import os - -import openai - -from haystack.preview import component, default_to_dict - - -@component -class OpenAITextEmbedder: - """ - A component for embedding strings using OpenAI models. - - Usage example: - ```python - from haystack.preview.components.embedders import OpenAITextEmbedder - - text_to_embed = "I love pizza!" - - text_embedder = OpenAITextEmbedder() - - print(text_embedder.run(text_to_embed)) - - # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], - # 'metadata': {'model': 'text-embedding-ada-002-v2', - # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} - ``` - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "text-embedding-ada-002", - organization: Optional[str] = None, - prefix: str = "", - suffix: str = "", - ): - """ - Create an OpenAITextEmbedder component. - - :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the - environment variable OPENAI_API_KEY (recommended). - :param model_name: The name of the OpenAI model to use. For more details on the available models, - see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models). - :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, - see [OpenAI documentation](https://platform.openai.com/docs/api-reference/requesting-organization). - :param prefix: A string to add to the beginning of each text. - :param suffix: A string to add to the end of each text. - """ - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "OpenAITextEmbedder expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - - self.model_name = model_name - self.organization = organization - self.prefix = prefix - self.suffix = suffix - - openai.api_key = api_key - if organization is not None: - openai.organization = organization - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name} - - def to_dict(self) -> Dict[str, Any]: - """ - This method overrides the default serializer in order to avoid leaking the `api_key` value passed - to the constructor. - """ - - return default_to_dict( - self, model_name=self.model_name, organization=self.organization, prefix=self.prefix, suffix=self.suffix - ) - - @component.output_types(embedding=List[float], metadata=Dict[str, Any]) - def run(self, text: str): - """Embed a string.""" - if not isinstance(text, str): - raise TypeError( - "OpenAITextEmbedder expects a string as an input." 
- "In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder." - ) - - text_to_embed = self.prefix + text + self.suffix - - # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py) - # replace newlines, which can negatively affect performance. - text_to_embed = text_to_embed.replace("\n", " ") - - response = openai.Embedding.create(model=self.model_name, input=text_to_embed) - - metadata = {"model": response.model, "usage": dict(response.usage.items())} - embedding = response.data[0]["embedding"] - - return {"embedding": embedding, "metadata": metadata} diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py deleted file mode 100644 index 7b7a5ca183..0000000000 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import List, Optional, Union, Dict, Any - -from haystack.preview import component, Document, default_to_dict -from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( - _SentenceTransformersEmbeddingBackendFactory, -) - - -@component -class SentenceTransformersDocumentEmbedder: - """ - A component for computing Document embeddings using Sentence Transformers models. - The embedding of each Document is stored in the `embedding` field of the Document. - - Usage example: - ```python - from haystack.preview import Document - from haystack.preview.components.embedders import SentenceTransformersDocumentEmbedder - doc = Document(text="I love pizza!") - doc_embedder = SentenceTransformersDocumentEmbedder() - doc_embedder.warm_up() - - result = doc_embedder.run([doc]) - print(result['documents'][0].embedding) - - # [-0.07804739475250244, 0.1498992145061493, ...] - ``` - """ - - def __init__( - self, - model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2", - device: Optional[str] = None, - token: Union[bool, str, None] = None, - prefix: str = "", - suffix: str = "", - batch_size: int = 32, - progress_bar: bool = True, - normalize_embeddings: bool = False, - metadata_fields_to_embed: Optional[List[str]] = None, - embedding_separator: str = "\n", - ): - """ - Create a SentenceTransformersDocumentEmbedder component. - - :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, - such as ``'sentence-transformers/all-mpnet-base-v2'``. - :param device: Device (like 'cuda' / 'cpu') that should be used for computation. - Defaults to CPU. - :param token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) will be used. - :param prefix: A string to add to the beginning of each Document text before embedding. - Can be used to prepend the text with an instruction, as required by some embedding models, - such as E5 and bge. - :param suffix: A string to add to the end of each Document text before embedding. - :param batch_size: Number of strings to encode at once. - :param progress_bar: If true, displays progress bar during embedding. - :param normalize_embeddings: If set to true, returned vectors will have length 1. - :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content. 
- :param embedding_separator: Separator used to concatenate the meta fields to the Document content. - """ - - self.model_name_or_path = model_name_or_path - # TODO: remove device parameter and use Haystack's device management once migrated - self.device = device or "cpu" - self.token = token - self.prefix = prefix - self.suffix = suffix - self.batch_size = batch_size - self.progress_bar = progress_bar - self.normalize_embeddings = normalize_embeddings - self.metadata_fields_to_embed = metadata_fields_to_embed or [] - self.embedding_separator = embedding_separator - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name_or_path} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, - model_name_or_path=self.model_name_or_path, - device=self.device, - token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens - prefix=self.prefix, - suffix=self.suffix, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - normalize_embeddings=self.normalize_embeddings, - metadata_fields_to_embed=self.metadata_fields_to_embed, - embedding_separator=self.embedding_separator, - ) - - def warm_up(self): - """ - Load the embedding backend. - """ - if not hasattr(self, "embedding_backend"): - self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token - ) - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]): - """ - Embed a list of Documents. - The embedding of each Document is stored in the `embedding` field of the Document. - """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): - raise TypeError( - "SentenceTransformersDocumentEmbedder expects a list of Documents as input." - "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." - ) - if not hasattr(self, "embedding_backend"): - raise RuntimeError("The embedding model has not been loaded. 
Please call warm_up() before running.") - - # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here - - texts_to_embed = [] - for doc in documents: - meta_values_to_embed = [ - str(doc.meta[key]) for key in self.metadata_fields_to_embed if key in doc.meta and doc.meta[key] - ] - text_to_embed = ( - self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix - ) - texts_to_embed.append(text_to_embed) - - embeddings = self.embedding_backend.embed( - texts_to_embed, - batch_size=self.batch_size, - show_progress_bar=self.progress_bar, - normalize_embeddings=self.normalize_embeddings, - ) - - for doc, emb in zip(documents, embeddings): - doc.embedding = emb - - return {"documents": documents} diff --git a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py deleted file mode 100644 index badb3997ee..0000000000 --- a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import List, Optional, Union, Dict, Any - -from haystack.preview import component, default_to_dict -from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( - _SentenceTransformersEmbeddingBackendFactory, -) - - -@component -class SentenceTransformersTextEmbedder: - """ - A component for embedding strings using Sentence Transformers models. - - Usage example: - ```python - from haystack.preview.components.embedders import SentenceTransformersTextEmbedder - - text_to_embed = "I love pizza!" - - text_embedder = SentenceTransformersTextEmbedder() - text_embedder.warm_up() - - print(text_embedder.run(text_to_embed)) - - # {'embedding': [-0.07804739475250244, 0.1498992145061493,, ...]} - ``` - """ - - def __init__( - self, - model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2", - device: Optional[str] = None, - token: Union[bool, str, None] = None, - prefix: str = "", - suffix: str = "", - batch_size: int = 32, - progress_bar: bool = True, - normalize_embeddings: bool = False, - ): - """ - Create a SentenceTransformersTextEmbedder component. - - :param model_name_or_path: Local path or name of the model in Hugging Face's model hub, - such as ``'sentence-transformers/all-mpnet-base-v2'``. - :param device: Device (like 'cuda' / 'cpu') that should be used for computation. - Defaults to CPU. - :param token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) will be used. - :param prefix: A string to add to the beginning of each Document text before embedding. - Can be used to prepend the text with an instruction, as required by some embedding models, - such as E5 and bge. - :param suffix: A string to add to the end of each text. - :param batch_size: Number of strings to encode at once. - :param progress_bar: If true, displays progress bar during embedding. - :param normalize_embeddings: If set to true, returned vectors will have length 1. 
- """ - - self.model_name_or_path = model_name_or_path - # TODO: remove device parameter and use Haystack's device management once migrated - self.device = device or "cpu" - self.token = token - self.prefix = prefix - self.suffix = suffix - self.batch_size = batch_size - self.progress_bar = progress_bar - self.normalize_embeddings = normalize_embeddings - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name_or_path} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, - model_name_or_path=self.model_name_or_path, - device=self.device, - token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens - prefix=self.prefix, - suffix=self.suffix, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - normalize_embeddings=self.normalize_embeddings, - ) - - def warm_up(self): - """ - Load the embedding backend. - """ - if not hasattr(self, "embedding_backend"): - self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token - ) - - @component.output_types(embedding=List[float]) - def run(self, text: str): - """Embed a string.""" - if not isinstance(text, str): - raise TypeError( - "SentenceTransformersTextEmbedder expects a string as input." - "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder." - ) - if not hasattr(self, "embedding_backend"): - raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.") - - text_to_embed = self.prefix + text + self.suffix - embedding = self.embedding_backend.embed( - [text_to_embed], - batch_size=self.batch_size, - show_progress_bar=self.progress_bar, - normalize_embeddings=self.normalize_embeddings, - )[0] - return {"embedding": embedding} diff --git a/haystack/preview/components/fetchers/__init__.py b/haystack/preview/components/fetchers/__init__.py deleted file mode 100644 index f0580d369a..0000000000 --- a/haystack/preview/components/fetchers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.fetchers.link_content import LinkContentFetcher - -__all__ = ["LinkContentFetcher"] diff --git a/haystack/preview/components/fetchers/link_content.py b/haystack/preview/components/fetchers/link_content.py deleted file mode 100644 index 664f716291..0000000000 --- a/haystack/preview/components/fetchers/link_content.py +++ /dev/null @@ -1,203 +0,0 @@ -import logging -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Dict, List, Optional, Tuple - -import requests -from requests import Response -from requests.exceptions import HTTPError -from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential - -from haystack.preview import component -from haystack.preview.dataclasses import ByteStream -from haystack.preview.version import __version__ - -logger = logging.getLogger(__name__) - - -DEFAULT_USER_AGENT = f"haystack/LinkContentFetcher/{__version__}" - -REQUEST_HEADERS = { - "accept": "*/*", - "User-Agent": DEFAULT_USER_AGENT, - "Accept-Language": "en-US,en;q=0.9,it;q=0.8,es;q=0.7", - "referer": "https://www.google.com/", -} - - -def text_content_handler(response: Response) -> ByteStream: - """ - 
:param response: Response object from the request. - :return: The extracted text. - """ - return ByteStream.from_string(response.text) - - -def binary_content_handler(response: Response) -> ByteStream: - """ - :param response: Response object from the request. - :return: The extracted binary file-like object. - """ - return ByteStream(data=response.content) - - -@component -class LinkContentFetcher: - """ - LinkContentFetcher is a component for fetching and extracting content from URLs. It supports handling various - content types, retries on failures, and automatic user-agent rotation for failed web requests. - """ - - def __init__( - self, - raise_on_failure: bool = True, - user_agents: Optional[List[str]] = None, - retry_attempts: int = 2, - timeout: int = 3, - ): - """ - Initializes a LinkContentFetcher instance. - - :param raise_on_failure: If True, raises an exception if it fails to fetch a single URL. - For multiple URLs, it logs errors and returns the content it successfully fetched. Default is True. - :param user_agents: A list of user agents for fetching content. If None, a default user agent is used. - :param retry_attempts: Specifies how many times you want it to retry to fetch the URL's content. Default is 2. - :param timeout: Timeout in seconds for the request. Default is 3. - """ - self.raise_on_failure = raise_on_failure - self.user_agents = user_agents or [DEFAULT_USER_AGENT] - self.current_user_agent_idx: int = 0 - self.retry_attempts = retry_attempts - self.timeout = timeout - - # register default content handlers that extract data from the response - self.handlers: Dict[str, Callable[[Response], ByteStream]] = defaultdict(lambda: text_content_handler) - self.handlers["text/html"] = text_content_handler - self.handlers["text/plain"] = text_content_handler - self.handlers["application/pdf"] = binary_content_handler - self.handlers["application/octet-stream"] = binary_content_handler - - @retry( - reraise=True, - stop=stop_after_attempt(self.retry_attempts), - wait=wait_exponential(multiplier=1, min=2, max=10), - retry=(retry_if_exception_type((HTTPError, requests.RequestException))), - # This method is invoked only after failed requests (exception raised) - after=self._switch_user_agent, - ) - def get_response(url): - # we need to copy because we modify the headers - headers = REQUEST_HEADERS.copy() - headers["User-Agent"] = self.user_agents[self.current_user_agent_idx] - response = requests.get(url, headers=headers, timeout=timeout or 3) - response.raise_for_status() - return response - - self._get_response: Callable = get_response - - @component.output_types(streams=List[ByteStream]) - def run(self, urls: List[str]): - """ - Fetches content from a list of URLs and returns a list of extracted content streams. - Each content stream is a ByteStream object containing the extracted content as binary data. - Each ByteStream object in the returned list corresponds to the contents of a single URL. - The content type of each stream is stored in the metadata of the ByteStream object under - the key "content_type". The URL of the fetched content is stored under the key "url". - - :param urls: A list of URLs to fetch content from. - :return: A lists of ByteStream objects representing the extracted content. - - :raises: If the provided list of URLs contains only a single URL, and `raise_on_failure` is set to True, - an exception will be raised in case of an error during content retrieval. 
In all other scenarios, any - retrieval errors are logged, and a list of successfully retrieved ByteStream objects is returned. - """ - streams: List[ByteStream] = [] - if not urls: - return {"streams": streams} - - # don't use multithreading if there's only one URL - if len(urls) == 1: - stream_metadata, stream = self.fetch(urls[0]) - stream.metadata.update(stream_metadata) - streams.append(stream) - else: - with ThreadPoolExecutor() as executor: - results = executor.map(self._fetch_with_exception_suppression, urls) - - for stream_metadata, stream in results: # type: ignore - if stream_metadata is not None and stream is not None: - stream.metadata.update(stream_metadata) - streams.append(stream) - - return {"streams": streams} - - def fetch(self, url: str) -> Tuple[Dict[str, str], ByteStream]: - """ - Fetches content from a URL and returns it as a ByteStream. - - :param url: The URL to fetch content from. - :return: A tuple containing the ByteStream metadata dict and the corresponding ByteStream. - ByteStream metadata contains the URL and the content type of the fetched content. - The content type is a string indicating the type of content fetched (for example, "text/html", "application/pdf"). - The ByteStream object contains the fetched content as binary data. - - :raises: If an error occurs during content retrieval and `raise_on_failure` is set to True, this method will - raise an exception. Otherwise, all fetching errors are logged, and an empty ByteStream is returned. - - """ - content_type: str = "text/html" - stream: ByteStream = ByteStream(data=b"") - try: - response = self._get_response(url) - content_type = self._get_content_type(response) - handler: Callable = self.handlers[content_type] - stream = handler(response) - except Exception as e: - if self.raise_on_failure: - raise e - # less verbose log as this is expected to happen often (requests failing, blocked, etc.) - logger.debug("Couldn't retrieve content from %s because %s", url, str(e)) - - finally: - self.current_user_agent_idx = 0 - - return {"content_type": content_type, "url": url}, stream - - def _fetch_with_exception_suppression(self, url: str) -> Tuple[Optional[Dict[str, str]], Optional[ByteStream]]: - """ - Fetches content from a URL and returns it as a ByteStream. - - If `raise_on_failure` is set to True, this method will wrap the fetch() method and catch any exceptions. - Otherwise, it will simply call the fetch() method. - :param url: The URL to fetch content from. - :return: A tuple containing the ByteStream metadata dict and the corresponding ByteStream. - - """ - if self.raise_on_failure: - try: - return self.fetch(url) - except Exception as e: - logger.warning("Error fetching %s: %s", url, str(e)) - return {"content_type": "Unknown", "url": url}, None - else: - return self.fetch(url) - - def _get_content_type(self, response: Response): - """ - Get the content type of the response. - - :param response: The response object. - :return: The content type of the response. - """ - content_type = response.headers.get("Content-Type", "") - return content_type.split(";")[0] - - def _switch_user_agent(self, retry_state: RetryCallState) -> None: - """ - Switches the User-Agent for this LinkContentRetriever to the next one in the list of user agents. - Used by tenacity to retry the requests with a different user agent. - - :param retry_state: The retry state (unused, required by tenacity). 
- """ - self.current_user_agent_idx = (self.current_user_agent_idx + 1) % len(self.user_agents) - logger.debug("Switched user agent to %s", self.user_agents[self.current_user_agent_idx]) diff --git a/haystack/preview/components/generators/__init__.py b/haystack/preview/components/generators/__init__.py deleted file mode 100644 index 037ca7b7a5..0000000000 --- a/haystack/preview/components/generators/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from haystack.preview.components.generators.cohere import CohereGenerator -from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator -from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator -from haystack.preview.components.generators.openai import GPTGenerator - -__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator", "CohereGenerator"] diff --git a/haystack/preview/components/generators/chat/__init__.py b/haystack/preview/components/generators/chat/__init__.py deleted file mode 100644 index 2126529d83..0000000000 --- a/haystack/preview/components/generators/chat/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator -from haystack.preview.components.generators.chat.openai import GPTChatGenerator - -__all__ = ["HuggingFaceTGIChatGenerator", "GPTChatGenerator"] diff --git a/haystack/preview/components/generators/chat/hugging_face_tgi.py b/haystack/preview/components/generators/chat/hugging_face_tgi.py deleted file mode 100644 index 4d4062ef24..0000000000 --- a/haystack/preview/components/generators/chat/hugging_face_tgi.py +++ /dev/null @@ -1,280 +0,0 @@ -import logging -from dataclasses import asdict -from typing import Any, Dict, List, Optional, Iterable, Callable -from urllib.parse import urlparse - -from haystack.preview import component, default_to_dict, default_from_dict -from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler -from haystack.preview.dataclasses import ChatMessage, StreamingChunk -from haystack.preview.components.generators.hf_utils import check_valid_model, check_generation_params -from haystack.preview.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install transformers'") as transformers_import: - from huggingface_hub import InferenceClient - from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token - from transformers import AutoTokenizer - -logger = logging.getLogger(__name__) - - -class HuggingFaceTGIChatGenerator: - """ - Enables text generation using HuggingFace Hub hosted chat-based LLMs. This component is designed to seamlessly - inference chat-based models deployed on the Text Generation Inference (TGI) backend. 
- - You can use this component for chat LLMs hosted on Hugging Face inference endpoints, the rate-limited - Inference API tier: - - ```python - from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator - from haystack.preview.dataclasses import ChatMessage - - messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), - ChatMessage.from_user("What's Natural Language Processing?")] - - - client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>") - client.warm_up() - response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) - print(response) - ``` - - For chat LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint and/or your own custom TGI - endpoint, you'll need to provide the URL of the endpoint as well as a valid token: - - ```python - from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator - from haystack.preview.dataclasses import ChatMessage - - messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), - ChatMessage.from_user("What's Natural Language Processing?")] - - client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", - url="<your-tgi-endpoint-url>", - token="<your-token>") - client.warm_up() - response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) - print(response) - ``` - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with any chat-based model deployed using the TGI - framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference. - - **Hugging Face Inference Endpoints**: Supports inference of TGI chat LLMs deployed on Hugging Face - inference endpoints. For more details, refer to https://huggingface.co/inference-endpoints. - - **Inference API Support**: Supports inference of TGI chat LLMs hosted on the rate-limited Inference - API tier. Learn more about the Inference API at https://huggingface.co/inference-api. - Discover available chat models using the following command: - ``` - wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat - ``` - and simply use the model ID as the model parameter for this component. You'll also need to provide a valid - Hugging Face API token as the token parameter. - - **Custom TGI Endpoints**: Supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can - deploy their own TGI endpoint using the TGI framework. For more details, refer - to https://huggingface.co/inference-endpoints. - - Input and Output Format: - - **ChatMessage Format**: This component uses the ChatMessage format to structure both input and output, - ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the - ChatMessage format can be found at https://github.com/openai/openai-python/blob/main/chatml.md. - - """ - - def __init__( - self, - model: str = "meta-llama/Llama-2-13b-chat-hf", - url: Optional[str] = None, - token: Optional[str] = None, - chat_template: Optional[str] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - ): - """ - Initialize the HuggingFaceTGIChatGenerator instance.
- - :param model: A string representing the model path or URL. Default is "meta-llama/Llama-2-13b-chat-hf". - :param url: An optional string representing the URL of the TGI endpoint. - :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat - messages. While high-quality and well-supported chat models typically include their own chat templates - accessible through their tokenizer, there are models that do not offer this feature. For such scenarios, - or if you wish to use a custom template instead of the model's default, you can use this parameter to - set your preferred chat template. - :param token: The Hugging Face token for HTTP bearer authorization. - You can find your HF token at https://huggingface.co/settings/tokens. - :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. - Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,... - See Hugging Face's [documentation](https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters) - for more information. - :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. - """ - transformers_import.check() - - if url: - r = urlparse(url) - is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) - if not is_valid_url: - raise ValueError(f"Invalid TGI endpoint URL provided: {url}") - - check_valid_model(model, token) - - # handle generation kwargs setup - generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} - check_generation_params(generation_kwargs, ["n"]) - generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) - generation_kwargs["stop_sequences"].extend(stop_words or []) - - self.model = model - self.url = url - self.chat_template = chat_template - self.token = token - self.generation_kwargs = generation_kwargs - self.client = InferenceClient(url or model, token=token) - self.streaming_callback = streaming_callback - self.tokenizer = None - - def warm_up(self) -> None: - """ - Load the tokenizer. - """ - self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) - # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained - chat_template = getattr(self.tokenizer, "chat_template", None) - if not chat_template and not self.chat_template: - logger.warning( - "The model '%s' doesn't have a default chat_template, and no chat_template was supplied during " - "this component's initialization. It’s possible that the model doesn't support ChatML inference " - "format, potentially leading to unexpected behavior.", - self.model, - ) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - - :return: A dictionary containing the serialized component. 
- """ - callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None - return default_to_dict( - self, - model=self.model, - url=self.url, - chat_template=self.chat_template, - token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens - generation_kwargs=self.generation_kwargs, - streaming_callback=callback_name, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIChatGenerator": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) - return default_from_dict(cls, data) - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - # Don't send URL as it is sensitive information - return {"model": self.model} - - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Invoke the text generation inference based on the provided messages and generation parameters. - - :param messages: A list of ChatMessage instances representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. - :return: A list containing the generated responses as ChatMessage instances. - """ - - # check generation kwargs given as parameters to override the default ones - additional_params = ["n", "stop_words"] - check_generation_params(generation_kwargs, additional_params) - - # update generation kwargs by merging with the default ones - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - num_responses = generation_kwargs.pop("n", 1) - - # merge stop_words and stop_sequences into a single list - generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) - generation_kwargs["stop_sequences"].extend(generation_kwargs.pop("stop_words", [])) - - if self.tokenizer is None: - raise RuntimeError("Please call warm_up() before running LLM inference.") - - # apply either model's chat template or the user-provided one - prepared_prompt: str = self.tokenizer.apply_chat_template( - conversation=messages, chat_template=self.chat_template, tokenize=False - ) - prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False)) - - if self.streaming_callback: - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - - return self._run_streaming(prepared_prompt, prompt_token_count, generation_kwargs) - - return self._run_non_streaming(prepared_prompt, prompt_token_count, num_responses, generation_kwargs) - - def _run_streaming( - self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any] - ) -> Dict[str, List[ChatMessage]]: - res: Iterable[TextGenerationStreamResponse] = self.client.text_generation( - prepared_prompt, stream=True, details=True, **generation_kwargs - ) - chunk = None - # pylint: disable=not-an-iterable - for chunk in res: - token: Token = chunk.token - if token.special: - continue - chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} - stream_chunk = StreamingChunk(token.text, chunk_metadata) - self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not 
None (verified in the run method) - - message = ChatMessage.from_assistant(chunk.generated_text) - message.metadata.update( - { - "finish_reason": chunk.details.finish_reason.value, - "index": 0, - "model": self.client.model, - "usage": { - "completion_tokens": chunk.details.generated_tokens, - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + chunk.details.generated_tokens, - }, - } - ) - return {"replies": [message]} - - def _run_non_streaming( - self, prepared_prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any] - ) -> Dict[str, List[ChatMessage]]: - chat_messages: List[ChatMessage] = [] - for _i in range(num_responses): - tgr: TextGenerationResponse = self.client.text_generation( - prepared_prompt, details=True, **generation_kwargs - ) - message = ChatMessage.from_assistant(tgr.generated_text) - message.metadata.update( - { - "finish_reason": tgr.details.finish_reason.value, - "index": _i, - "model": self.client.model, - "usage": { - "completion_tokens": len(tgr.details.tokens), - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + len(tgr.details.tokens), - }, - } - ) - chat_messages.append(message) - return {"replies": chat_messages} diff --git a/haystack/preview/components/generators/chat/openai.py b/haystack/preview/components/generators/chat/openai.py deleted file mode 100644 index 0b219378ee..0000000000 --- a/haystack/preview/components/generators/chat/openai.py +++ /dev/null @@ -1,287 +0,0 @@ -import dataclasses -import logging -import os -from typing import Optional, List, Callable, Dict, Any - -import openai -from openai.openai_object import OpenAIObject - -from haystack.preview import component, default_from_dict, default_to_dict -from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler -from haystack.preview.dataclasses import StreamingChunk, ChatMessage - -logger = logging.getLogger(__name__) - - -API_BASE_URL = "https://api.openai.com/v1" - - -@component -class GPTChatGenerator: - """ - Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo - family of models accessed through the chat completions API endpoint. - - Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method - directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` - parameter in `run` method. - - For more details on the parameters supported by the OpenAI API, refer to the OpenAI - [documentation](https://platform.openai.com/docs/api-reference/chat). 
- - ```python - from haystack.preview.components.generators.chat import GPTChatGenerator - from haystack.preview.dataclasses import ChatMessage - - messages = [ChatMessage.from_user("What's Natural Language Processing?")] - - client = GPTChatGenerator() - response = client.run(messages) - print(response) - - >>{'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence - >>that focuses on enabling computers to understand, interpret, and generate human language in a way that is - >>meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None, - >>metadata={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', - >>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]} - - ``` - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with the OpenAI API Chat Completion endpoint - and gpt-4 and gpt-3.5-turbo family of models. - - **Streaming Support**: Supports streaming responses from the OpenAI API Chat Completion endpoint. - - **Customizability**: Supports all parameters supported by the OpenAI API Chat Completion endpoint. - - Input and Output Format: - - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output, - ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the - ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md. - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "gpt-3.5-turbo", - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - api_base_url: str = API_BASE_URL, - generation_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's - GPT-3.5 model. - - :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the - environment variable OPENAI_API_KEY (recommended). - :param model_name: The name of the model to use. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - The callback function accepts StreamingChunk as an argument. - :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. - :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to - the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for - more details. - Some of the supported parameters: - - `max_tokens`: The maximum number of tokens the output text can have. - - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. - Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer. - - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model - considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens - comprising the top 10% probability mass are considered. - - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, - it will generate two completions for each of the three prompts, ending up with 6 completions in total. - - `stop`: One or more sequences after which the LLM should stop generating tokens.
- - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean - the model will be less likely to repeat the same token in the text. - - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. - Bigger values mean the model will be less likely to repeat the same token in the text. - - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - values are the bias to add to that token. - """ - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "GPTChatGenerator expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - openai.api_key = api_key - - self.model_name = model_name - self.generation_kwargs = generation_kwargs or {} - self.streaming_callback = streaming_callback - - self.api_base_url = api_base_url - openai.api_base = api_base_url - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - :return: The serialized component as a dictionary. - """ - callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None - return default_to_dict( - self, - model_name=self.model_name, - streaming_callback=callback_name, - api_base_url=self.api_base_url, - generation_kwargs=self.generation_kwargs, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GPTChatGenerator": - """ - Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. - :return: The deserialized component instance. - """ - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) - return default_from_dict(cls, data) - - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Invoke the text generation inference based on the provided messages and generation parameters. - - :param messages: A list of ChatMessage instances representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - potentially override the parameters passed in the __init__ method. - For more details on the parameters supported by the OpenAI API, refer to the - OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create). - :return: A list containing the generated responses as ChatMessage instances. 
- """ - - # update generation kwargs by merging with the generation kwargs passed to the run method - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - - # adapt ChatMessage(s) to the format expected by the OpenAI API - openai_formatted_messages = self._convert_to_openai_format(messages) - - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=openai_formatted_messages, - stream=self.streaming_callback is not None, - **generation_kwargs, - ) - - completions: List[ChatMessage] - if self.streaming_callback: - num_responses = generation_kwargs.pop("n", 1) - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - chunks: List[StreamingChunk] = [] - chunk = None - for chunk in completion: - if chunk.choices: - chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0]) - chunks.append(chunk_delta) - self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta - completions = [self._connect_chunks(chunk, chunks)] - else: - completions = [self._build_message(completion, choice) for choice in completion.choices] - - # before returning, do post-processing of the completions - for message in completions: - self._check_finish_reason(message) - - return {"replies": completions} - - def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: - """ - Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API. - :param messages: The list of ChatMessage. - :return: The list of messages in the format expected by the OpenAI API. - """ - openai_chat_message_format = {"role", "content", "name"} - openai_formatted_messages = [] - for m in messages: - message_dict = dataclasses.asdict(m) - filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v} - openai_formatted_messages.append(filtered_message) - return openai_formatted_messages - - def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage: - """ - Connects the streaming chunks into a single ChatMessage. - :param chunk: The last chunk returned by the OpenAI API. - :param chunks: The list of all chunks returned by the OpenAI API. - """ - complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) - complete_response.metadata.update( - { - "model": chunk.model, - "index": 0, - "finish_reason": chunk.choices[0].finish_reason, - "usage": {}, # we don't have usage data for streaming responses - } - ) - return complete_response - - def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage: - """ - Converts the non-streaming response from the OpenAI API to a ChatMessage. - :param completion: The completion returned by the OpenAI API. - :param choice: The choice returned by the OpenAI API. - :return: The ChatMessage. 
- """ - message: OpenAIObject = choice.message - # message.content is str but message.function_call is OpenAIObject but JSON in fact, convert to str - content = str(message.function_call) if choice.finish_reason == "function_call" else message.content - chat_message = ChatMessage.from_assistant(content) - chat_message.metadata.update( - { - "model": completion.model, - "index": choice.index, - "finish_reason": choice.finish_reason, - "usage": dict(completion.usage.items()), - } - ) - return chat_message - - def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk: - """ - Converts the streaming response chunk from the OpenAI API to a StreamingChunk. - :param chunk: The chunk returned by the OpenAI API. - :param choice: The choice returned by the OpenAI API. - :return: The StreamingChunk. - """ - has_content = bool(hasattr(choice.delta, "content") and choice.delta.content) - if has_content: - content = choice.delta.content - elif hasattr(choice.delta, "function_call"): - content = choice.delta.function_call - else: - content = "" - chunk_message = StreamingChunk(content) - chunk_message.metadata.update( - {"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason} - ) - return chunk_message - - def _check_finish_reason(self, message: ChatMessage) -> None: - """ - Check the `finish_reason` returned with the OpenAI completions. - If the `finish_reason` is `length` or `content_filter`, log a warning. - :param message: The message returned by the LLM. - """ - if message.metadata["finish_reason"] == "length": - logger.warning( - "The completion for index %s has been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions.", - message.metadata["index"], - ) - if message.metadata["finish_reason"] == "content_filter": - logger.warning( - "The completion for index %s has been truncated due to the content filter.", message.metadata["index"] - ) diff --git a/haystack/preview/components/generators/cohere.py b/haystack/preview/components/generators/cohere.py deleted file mode 100644 index ee7106a5b1..0000000000 --- a/haystack/preview/components/generators/cohere.py +++ /dev/null @@ -1,159 +0,0 @@ -import logging -import os -import sys -from typing import Any, Callable, Dict, List, Optional - -from haystack.preview.lazy_imports import LazyImport -from haystack.preview import DeserializationError, component, default_from_dict, default_to_dict - -with LazyImport(message="Run 'pip install cohere'") as cohere_import: - from cohere import Client, COHERE_API_URL - -logger = logging.getLogger(__name__) - - -@component -class CohereGenerator: - """LLM Generator compatible with Cohere's generate endpoint. - - Queries the LLM using Cohere's API. Invocations are made using 'cohere' package. - See [Cohere API](https://docs.cohere.com/reference/generate) for more details. - - Example usage: - - ```python - from haystack.preview.generators import CohereGenerator - generator = CohereGenerator(api_key="test-api-key") - generator.run(prompt="What's the capital of France?") - ``` - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "command", - streaming_callback: Optional[Callable] = None, - api_base_url: Optional[str] = None, - **kwargs, - ): - """ - Instantiates a `CohereGenerator` component. - :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. 
- :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". - :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. - :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". - :param kwargs: Additional model parameters. These will be used during generation. Refer to https://docs.cohere.com/reference/generate for more details. - Some of the parameters are: - - 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024. - - 'truncate': One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length. Defaults to END. - - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. - - 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the playground. - - 'end_sequences': The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text. - - 'stop_sequences': The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text. - - 'k': Defaults to 0, min value of 0.01, max value of 0.99. - - 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. - - 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, - proportional to how many times they have already appeared in the prompt or prior generation.' - - 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce repetitiveness of generated tokens. - Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. - - 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE. - - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. - The format is {token_id: bias} where bias is a float between -10 and 10. - - """ - cohere_import.check() - - if not api_key: - api_key = os.environ.get("COHERE_API_KEY") - if not api_key: - raise ValueError( - "CohereGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." - ) - - if not api_base_url: - api_base_url = COHERE_API_URL - - self.api_key = api_key - self.model_name = model_name - self.streaming_callback = streaming_callback - self.api_base_url = api_base_url - self.model_parameters = kwargs - self.client = Client(api_key=self.api_key, api_url=self.api_base_url) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. 
- """ - if self.streaming_callback: - module = self.streaming_callback.__module__ - if module == "builtins": - callback_name = self.streaming_callback.__name__ - else: - callback_name = f"{module}.{self.streaming_callback.__name__}" - else: - callback_name = None - - return default_to_dict( - self, - model_name=self.model_name, - streaming_callback=callback_name, - api_base_url=self.api_base_url, - **self.model_parameters, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - streaming_callback = None - if "streaming_callback" in init_params and init_params["streaming_callback"]: - parts = init_params["streaming_callback"].split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}") - streaming_callback = getattr(module, function_name, None) - if not streaming_callback: - raise DeserializationError(f"Could not locate the streaming callback: {function_name}") - data["init_parameters"]["streaming_callback"] = streaming_callback - return default_from_dict(cls, data) - - @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) - def run(self, prompt: str): - """ - Queries the LLM with the prompts to produce replies. - :param prompt: The prompt to be sent to the generative model. - """ - response = self.client.generate( - model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters - ) - if self.streaming_callback: - metadata_dict: Dict[str, Any] = {} - for chunk in response: - self.streaming_callback(chunk) - metadata_dict["index"] = chunk.index - replies = response.texts - metadata_dict["finish_reason"] = response.finish_reason - metadata = [metadata_dict] - self._check_truncated_answers(metadata) - return {"replies": replies, "metadata": metadata} - - metadata = [{"finish_reason": resp.finish_reason} for resp in response] - replies = [resp.text for resp in response] - self._check_truncated_answers(metadata) - return {"replies": replies, "metadata": metadata} - - def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): - """ - Check the `finish_reason` returned with the Cohere response. - If the `finish_reason` is `MAX_TOKEN`, log a warning to the user. - :param metadata: The metadata returned by the Cohere API. - """ - if metadata[0]["finish_reason"] == "MAX_TOKENS": - logger.warning( - "Responses have been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." - ) diff --git a/haystack/preview/components/generators/hf_utils.py b/haystack/preview/components/generators/hf_utils.py deleted file mode 100644 index 9eca92ae5a..0000000000 --- a/haystack/preview/components/generators/hf_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -import inspect -from typing import Any, Dict, List, Optional - -from haystack.preview.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install transformers'") as transformers_import: - from huggingface_hub import InferenceClient, HfApi - from huggingface_hub.utils import RepositoryNotFoundError - - -def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepted_params: Optional[List[str]] = None): - """ - Check the provided generation parameters for validity. 
- - :param kwargs: A dictionary containing the generation parameters. - :param additional_accepted_params: An optional list of strings representing additional accepted parameters. - :raises ValueError: If any unknown text generation parameters are provided. - """ - transformers_import.check() - - if kwargs: - accepted_params = { - param - for param in inspect.signature(InferenceClient.text_generation).parameters.keys() - if param not in ["self", "prompt"] - } - if additional_accepted_params: - accepted_params.update(additional_accepted_params) - unknown_params = set(kwargs.keys()) - accepted_params - if unknown_params: - raise ValueError( - f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}." - ) - - -def check_valid_model(model_id: str, token: Optional[str]) -> None: - """ - Check if the provided model ID corresponds to a valid model on HuggingFace Hub. - Also check if the model is a text generation model. - - :param model_id: A string representing the HuggingFace model ID. - :param token: An optional string representing the authentication token. - :raises ValueError: If the model is not found or is not a text generation model. - """ - transformers_import.check() - - api = HfApi() - try: - model_info = api.model_info(model_id, token=token) - except RepositoryNotFoundError as e: - raise ValueError( - f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id." - ) from e - - allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"] - if not allowed_model: - raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.") diff --git a/haystack/preview/components/generators/hugging_face_local.py b/haystack/preview/components/generators/hugging_face_local.py deleted file mode 100644 index 91dc639588..0000000000 --- a/haystack/preview/components/generators/hugging_face_local.py +++ /dev/null @@ -1,236 +0,0 @@ -import logging -from typing import Any, Dict, List, Literal, Optional, Union -from copy import deepcopy - -from haystack.preview import component, default_to_dict -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - -SUPPORTED_TASKS = ["text-generation", "text2text-generation"] - -with LazyImport(message="Run 'pip install transformers[torch]'") as torch_and_transformers_import: - import torch - from huggingface_hub import model_info - from transformers import ( - pipeline, - StoppingCriteriaList, - StoppingCriteria, - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ) - - class StopWordsCriteria(StoppingCriteria): - """ - Stops text generation if any one of the stop words is generated. - - Note: When a stop word is encountered, the generation of new text is stopped. - However, if the stop word is in the prompt itself, it can stop generating new text - prematurely after the first token. This is particularly important for LLMs designed - for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat, - the output includes both the new text and the original prompt. Therefore, it's important - to make sure your prompt has no stop words. 
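A short sketch of how the two validation helpers above behave, assuming `transformers` and `huggingface_hub` are installed; the `check_valid_model` call performs a network request against the Hugging Face Hub, and the model id is only an example:

```python
from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model

# Accepted keys are the keyword arguments of InferenceClient.text_generation
# (optionally extended by the caller); anything else raises a ValueError.
check_generation_params({"max_new_tokens": 100, "temperature": 0.7})

try:
    check_generation_params({"max_new_tokens": 100, "num_beams": 4})
except ValueError as err:
    print(err)  # names the unknown parameter and lists the accepted ones

# Raises ValueError if the repository does not exist or is not a
# text-generation / text2text-generation model.
check_valid_model("google/flan-t5-small", token=None)
```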
- """ - - def __init__( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - stop_words: List[str], - device: Union[str, torch.device] = "cpu", - ): - super().__init__() - encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt") - self.stop_ids = encoded_stop_words.input_ids.to(device) - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - for stop_id in self.stop_ids: - found_stop_word = self.is_stop_word_found(input_ids, stop_id) - if found_stop_word: - return True - return False - - def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool: - generated_text_ids = generated_text_ids[-1] - len_generated_text_ids = generated_text_ids.size(0) - len_stop_id = stop_id.size(0) - result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id)) - return result - - -@component -class HuggingFaceLocalGenerator: - """ - Generator based on a Hugging Face model. - This component provides an interface to generate text using a Hugging Face model that runs locally. - - Usage example: - ```python - from haystack.preview.components.generators import HuggingFaceLocalGenerator - - generator = HuggingFaceLocalGenerator(model="google/flan-t5-large", - task="text2text-generation", - generation_kwargs={ - "max_new_tokens": 100, - "temperature": 0.9, - }) - - print(generator.run("Who is the best American actor?")) - # {'replies': ['John Cusack']} - ``` - """ - - def __init__( - self, - model_name_or_path: str = "google/flan-t5-base", - task: Optional[Literal["text-generation", "text2text-generation"]] = None, - device: Optional[str] = None, - token: Optional[Union[str, bool]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, - stop_words: Optional[List[str]] = None, - ): - """ - :param model_name_or_path: The name or path of a Hugging Face model for text generation, - for example, "google/flan-t5-large". - If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. - :param task: The task for the Hugging Face pipeline. - Possible values are "text-generation" and "text2text-generation". - Generally, decoder-only models like GPT support "text-generation", - while encoder-decoder models like T5 support "text2text-generation". - If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. - If not specified, the component will attempt to infer the task from the model name, - calling the Hugging Face Hub API. - :param device: The device on which the model is loaded. (e.g., "cpu", "cuda:0"). - If `device` or `device_map` is specified in the `huggingface_pipeline_kwargs`, - this parameter will be ignored. - :param token: The token to use as HTTP bearer authorization for remote files. - If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface). - If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. - :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. - Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,... 
- See Hugging Face's documentation for more information: - - https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation - - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig - :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the - Hugging Face pipeline for text generation. - These keyword arguments provide fine-grained control over the Hugging Face pipeline. - In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters. - See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task) - for more information on the available kwargs. - In this dictionary, you can also include `model_kwargs` to specify the kwargs - for model initialization: - https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained - :param stop_words: A list of stop words. If any one of the stop words is generated, the generation is stopped. - If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`. - For some chat models, the output includes both the new text and the original prompt. - In these cases, it's important to make sure your prompt has no stop words. - """ - torch_and_transformers_import.check() - - huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} - generation_kwargs = generation_kwargs or {} - - # check if the huggingface_pipeline_kwargs contain the essential parameters - # otherwise, populate them with values from other init parameters - huggingface_pipeline_kwargs.setdefault("model", model_name_or_path) - huggingface_pipeline_kwargs.setdefault("token", token) - if ( - device is not None - and "device" not in huggingface_pipeline_kwargs - and "device_map" not in huggingface_pipeline_kwargs - ): - huggingface_pipeline_kwargs["device"] = device - - # task identification and validation - if task is None: - if "task" in huggingface_pipeline_kwargs: - task = huggingface_pipeline_kwargs["task"] - elif isinstance(huggingface_pipeline_kwargs["model"], str): - task = model_info( - huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"] - ).pipeline_tag - - if task not in SUPPORTED_TASKS: - raise ValueError( - f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}." - ) - huggingface_pipeline_kwargs["task"] = task - - # if not specified, set return_full_text to False for text-generation - # only generated text is returned (excluding prompt) - if task == "text-generation": - generation_kwargs.setdefault("return_full_text", False) - - if stop_words and "stopping_criteria" in generation_kwargs: - raise ValueError( - "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. " - "Please specify only one of them." - ) - - self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs - self.generation_kwargs = generation_kwargs - self.stop_words = stop_words - self.pipeline = None - self.stopping_criteria_list = None - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. 
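A minimal usage sketch for the `HuggingFaceLocalGenerator` parameters described above, assuming `transformers[torch]` is installed and the model can be downloaded; the model id, stop word, and prompt are illustrative:

```python
from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model_name_or_path="google/flan-t5-small",
    task="text2text-generation",          # skips the Hub call that would otherwise infer the task
    device="cpu",
    generation_kwargs={"max_new_tokens": 50},
    stop_words=["###"],                   # turned into a StopWordsCriteria at warm-up time
)
generator.warm_up()                       # builds the transformers pipeline and the stopping criteria
result = generator.run("What is the capital of France?")
print(result["replies"])                  # e.g. ['Paris']
```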
- """ - if isinstance(self.huggingface_pipeline_kwargs["model"], str): - return {"model": self.huggingface_pipeline_kwargs["model"]} - return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"} - - def warm_up(self): - if self.pipeline is None: - self.pipeline = pipeline(**self.huggingface_pipeline_kwargs) - - if self.stop_words and self.stopping_criteria_list is None: - stop_words_criteria = StopWordsCriteria( - tokenizer=self.pipeline.tokenizer, stop_words=self.stop_words, device=self.pipeline.device - ) - self.stopping_criteria_list = StoppingCriteriaList([stop_words_criteria]) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - pipeline_kwargs_to_serialize = deepcopy(self.huggingface_pipeline_kwargs) - - # we don't want to serialize valid tokens - if isinstance(pipeline_kwargs_to_serialize["token"], str): - pipeline_kwargs_to_serialize["token"] = None - - return default_to_dict( - self, - huggingface_pipeline_kwargs=pipeline_kwargs_to_serialize, - generation_kwargs=self.generation_kwargs, - stop_words=self.stop_words, - ) - - @component.output_types(replies=List[str]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Run the text generation model on the given prompt. - - :param prompt: A string representing the prompt. - :param generation_kwargs: Additional keyword arguments for text generation. - :return: A dictionary containing the generated replies. - """ - if self.pipeline is None: - raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") - - if not prompt: - return {"replies": []} - - # merge generation kwargs from init method with those from run method - updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - - output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) - replies = [o["generated_text"] for o in output if "generated_text" in o] - - if self.stop_words: - # the output of the pipeline includes the stop word - replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words] - - return {"replies": replies} diff --git a/haystack/preview/components/generators/hugging_face_tgi.py b/haystack/preview/components/generators/hugging_face_tgi.py deleted file mode 100644 index 71dc64acd7..0000000000 --- a/haystack/preview/components/generators/hugging_face_tgi.py +++ /dev/null @@ -1,237 +0,0 @@ -import logging -from dataclasses import asdict -from typing import Any, Dict, List, Optional, Iterable, Callable -from urllib.parse import urlparse - -from haystack.preview import component, default_to_dict, default_from_dict -from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler -from haystack.preview.dataclasses import StreamingChunk -from haystack.preview.components.generators.hf_utils import check_generation_params, check_valid_model -from haystack.preview.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install transformers'") as transformers_import: - from huggingface_hub import InferenceClient - from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, TextGenerationResponse, Token - from transformers import AutoTokenizer - - -logger = logging.getLogger(__name__) - - -@component -class HuggingFaceTGIGenerator: - """ - Enables text generation using HuggingFace Hub hosted non-chat LLMs. 
This component is designed to seamlessly - inference models deployed on the Text Generation Inference (TGI) backend. - - You can use this component for LLMs hosted on Hugging Face inference endpoints, the rate-limited - Inference API tier: - - ```python - from haystack.preview.components.generators import HuggingFaceTGIGenerator - client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="") - client.warm_up() - response = client.run("What's Natural Language Processing?", max_new_tokens=120) - print(response) - ``` - - Or for LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint, and/or your own custom TGI endpoint. - In these two cases, you'll need to provide the URL of the endpoint as well as a valid token: - - ```python - from haystack.preview.components.generators import HuggingFaceTGIGenerator - client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", - url="", - token="") - client.warm_up() - response = client.run("What's Natural Language Processing?", max_new_tokens=120) - print(response) - ``` - - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with any non-chat model deployed using the TGI - framework. For more information on TGI, visit https://github.com/huggingface/text-generation-inference. - - **Hugging Face Inference Endpoints**: Supports inference of TGI chat LLMs deployed on Hugging Face - inference endpoints. For more details refer to https://huggingface.co/inference-endpoints. - - **Inference API Support**: Supports inference of TGI LLMs hosted on the rate-limited Inference - API tier. Learn more about the Inference API at: https://huggingface.co/inference-api - Discover available LLMs using the following command: - ``` - wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference - ``` - And simply use the model ID as the model parameter for this component. You'll also need to provide a valid - Hugging Face API token as the token parameter. - - **Custom TGI Endpoints**: Supports inference of LLMs deployed on custom TGI endpoints. Anyone can - deploy their own TGI endpoint using the TGI framework. For more details refer - to https://huggingface.co/inference-endpoints. - Input and Output Format: - - **String Format**: This component uses the str format for structuring both input and output, - ensuring coherent and contextually relevant responses in text generation scenarios. - """ - - def __init__( - self, - model: str = "mistralai/Mistral-7B-v0.1", - url: Optional[str] = None, - token: Optional[str] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - ): - """ - Initialize the HuggingFaceTGIGenerator instance. - - :param model: A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1". - :param url: An optional string representing the URL of the TGI endpoint. - :param token: The HuggingFace token to use as HTTP bearer authorization - You can find your HF token at https://huggingface.co/settings/tokens - :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. - Some examples: `max_new_tokens`, `temperature`, `top_k`, `top_p`,... 
- See Hugging Face's documentation for more information at: - https://huggingface.co/docs/huggingface_hub/v0.18.0.rc0/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters - :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. - """ - transformers_import.check() - - if url: - r = urlparse(url) - is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) - if not is_valid_url: - raise ValueError(f"Invalid TGI endpoint URL provided: {url}") - - check_valid_model(model, token) - - # handle generation kwargs setup - generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} - check_generation_params(generation_kwargs, ["n"]) - generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) - generation_kwargs["stop_sequences"].extend(stop_words or []) - - self.model = model - self.url = url - self.token = token - self.generation_kwargs = generation_kwargs - self.client = InferenceClient(url or model, token=token) - self.streaming_callback = streaming_callback - self.tokenizer = None - - def warm_up(self) -> None: - """ - Load the tokenizer - """ - self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - - :return: A dictionary containing the serialized component. - """ - callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None - return default_to_dict( - self, - model=self.model, - url=self.url, - token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens - generation_kwargs=self.generation_kwargs, - streaming_callback=callback_name, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTGIGenerator": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) - return default_from_dict(cls, data) - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - # Don't send URL as it is sensitive information - return {"model": self.model} - - @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Invoke the text generation inference for the given prompt and generation parameters. - - :param prompt: A string representing the prompt. - :param generation_kwargs: Additional keyword arguments for text generation. - :return: A dictionary containing the generated replies and metadata. Both are lists of length n. - Replies are strings and metadata are dictionaries. 
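To illustrate the output contract described above, a sketch of a non-streaming call that asks for two completions; it assumes a valid Hugging Face token (`<your-hf-token>` is a placeholder) and that the model is reachable on the rate-limited Inference API:

```python
from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator

client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-hf-token>")
client.warm_up()  # loads the tokenizer used to count prompt tokens for the usage metadata

result = client.run(
    "What's Natural Language Processing?",
    generation_kwargs={"max_new_tokens": 120, "n": 2},
)
# replies and metadata are parallel lists of length n
for reply, meta in zip(result["replies"], result["metadata"]):
    print(meta["index"], meta["finish_reason"], meta["usage"]["total_tokens"])
    print(reply)
```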
- """ - # check generation kwargs given as parameters to override the default ones - additional_params = ["n", "stop_words"] - check_generation_params(generation_kwargs, additional_params) - - # update generation kwargs by merging with the default ones - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - num_responses = generation_kwargs.pop("n", 1) - generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", [])) - - if self.tokenizer is None: - raise RuntimeError("Please call warm_up() before running LLM inference.") - - prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False)) - - if self.streaming_callback: - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - - return self._run_streaming(prompt, prompt_token_count, generation_kwargs) - - return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs) - - def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]): - res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation( - prompt, details=True, stream=True, **generation_kwargs - ) - chunks: List[StreamingChunk] = [] - # pylint: disable=not-an-iterable - for chunk in res_chunk: - token: Token = chunk.token - if token.special: - continue - chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} - stream_chunk = StreamingChunk(token.text, chunk_metadata) - chunks.append(stream_chunk) - self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) - metadata = { - "finish_reason": chunks[-1].metadata.get("finish_reason", None), - "model": self.client.model, - "usage": { - "completion_tokens": chunks[-1].metadata.get("generated_tokens", 0), - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + chunks[-1].metadata.get("generated_tokens", 0), - }, - } - return {"replies": ["".join([chunk.content for chunk in chunks])], "metadata": [metadata]} - - def _run_non_streaming( - self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any] - ): - responses: List[str] = [] - all_metadata: List[Dict[str, Any]] = [] - for _i in range(num_responses): - tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs) - all_metadata.append( - { - "model": self.client.model, - "index": _i, - "finish_reason": tgr.details.finish_reason.value, - "usage": { - "completion_tokens": len(tgr.details.tokens), - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + len(tgr.details.tokens), - }, - } - ) - responses.append(tgr.generated_text) - return {"replies": responses, "metadata": all_metadata} diff --git a/haystack/preview/components/generators/openai.py b/haystack/preview/components/generators/openai.py deleted file mode 100644 index 34316637c1..0000000000 --- a/haystack/preview/components/generators/openai.py +++ /dev/null @@ -1,290 +0,0 @@ -import dataclasses -import logging -import os -from typing import Optional, List, Callable, Dict, Any - -import openai -from openai.openai_object import OpenAIObject - -from haystack.preview import component, default_from_dict, default_to_dict -from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler -from haystack.preview.dataclasses import StreamingChunk, ChatMessage - -logger = 
logging.getLogger(__name__) - - -API_BASE_URL = "https://api.openai.com/v1" - - -@component -class GPTGenerator: - """ - Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo - family of models. - - Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method - directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` - parameter in `run` method. - - For more details on the parameters supported by the OpenAI API, refer to the OpenAI - [documentation](https://platform.openai.com/docs/api-reference/chat). - - ```python - from haystack.preview.components.generators import GPTGenerator - client = GPTGenerator() - response = client.run("What's Natural Language Processing? Be brief.") - print(response) - - >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on - >> the interaction between computers and human language. It involves enabling computers to understand, interpret, - >> and respond to natural human language in a way that is both meaningful and useful.'], 'metadata': [{'model': - >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16, - >> 'completion_tokens': 49, 'total_tokens': 65}}]} - ``` - - Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with gpt-4, gpt-3.5-turbo family of models. - - **Streaming Support**: Supports streaming responses from the OpenAI API. - - **Customizability**: Supports all parameters supported by the OpenAI API. - - Input and Output Format: - - **String Format**: This component uses the strings for both input and output. - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "gpt-3.5-turbo", - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - api_base_url: str = API_BASE_URL, - system_prompt: Optional[str] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's - GPT-3.5 model. - - :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the - environment variable OPENAI_API_KEY (recommended). - :param model_name: The name of the model to use. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - The callback function accepts StreamingChunk as an argument. - :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. - :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is - omitted, and the default system prompt of the model is used. - :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to - the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for - more details. - Some of the supported parameters: - - `max_tokens`: The maximum number of tokens the output text can have. - - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. - Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer. 
- - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model - considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens - comprising the top 10% probability mass are considered. - - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, - it will generate two completions for each of the three prompts, ending up with 6 completions in total. - - `stop`: One or more sequences after which the LLM should stop generating tokens. - - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean - the model will be less likely to repeat the same token in the text. - - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. - Bigger values mean the model will be less likely to repeat the same token in the text. - - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - values are the bias to add to that token. - """ - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "GPTGenerator expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - openai.api_key = api_key - - self.model_name = model_name - self.generation_kwargs = generation_kwargs or {} - self.system_prompt = system_prompt - self.streaming_callback = streaming_callback - - self.api_base_url = api_base_url - openai.api_base = api_base_url - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - :return: The serialized component as a dictionary. - """ - callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None - return default_to_dict( - self, - model_name=self.model_name, - streaming_callback=callback_name, - api_base_url=self.api_base_url, - generation_kwargs=self.generation_kwargs, - system_prompt=self.system_prompt, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator": - """ - Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. - :return: The deserialized component instance. - """ - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) - return default_from_dict(cls, data) - - @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) - def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - """ - Invoke the text generation inference based on the provided messages and generation parameters. - - :param prompt: The string prompt to use for text generation. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - potentially override the parameters passed in the __init__ method. 
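As a sketch of how per-call parameters interact with the ones set at construction time, assuming `OPENAI_API_KEY` is set in the environment (the prompt is illustrative):

```python
from haystack.preview.components.generators.openai import GPTGenerator

client = GPTGenerator(
    model_name="gpt-3.5-turbo",
    generation_kwargs={"temperature": 0.2, "max_tokens": 128},  # defaults for every call
)

# The run-time kwargs are merged on top of the init-time ones, so temperature
# becomes 0.9 for this invocation while max_tokens stays at 128.
response = client.run(
    "Write a one-line tagline for a search engine.",
    generation_kwargs={"temperature": 0.9},
)
print(response["replies"][0])
print(response["metadata"][0]["usage"])  # prompt, completion and total token counts
```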
- For more details on the parameters supported by the OpenAI API, refer to the - OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create). - :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata - for each response. - """ - message = ChatMessage.from_user(prompt) - if self.system_prompt: - messages = [ChatMessage.from_system(self.system_prompt), message] - else: - messages = [message] - - # update generation kwargs by merging with the generation kwargs passed to the run method - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - - # adapt ChatMessage(s) to the format expected by the OpenAI API - openai_formatted_messages = self._convert_to_openai_format(messages) - - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=openai_formatted_messages, - stream=self.streaming_callback is not None, - **generation_kwargs, - ) - - completions: List[ChatMessage] - if self.streaming_callback: - num_responses = generation_kwargs.pop("n", 1) - if num_responses > 1: - raise ValueError("Cannot stream multiple responses, please set n=1.") - chunks: List[StreamingChunk] = [] - chunk = None - for chunk in completion: - if chunk.choices: - chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0]) - chunks.append(chunk_delta) - self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta - completions = [self._connect_chunks(chunk, chunks)] - else: - completions = [self._build_message(completion, choice) for choice in completion.choices] - - # before returning, do post-processing of the completions - for completion in completions: - self._check_finish_reason(completion) - - return { - "replies": [message.content for message in completions], - "metadata": [message.metadata for message in completions], - } - - def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: - """ - Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API. - :param messages: The list of ChatMessage. - :return: The list of messages in the format expected by the OpenAI API. - """ - openai_chat_message_format = {"role", "content", "name"} - openai_formatted_messages = [] - for m in messages: - message_dict = dataclasses.asdict(m) - filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v} - openai_formatted_messages.append(filtered_message) - return openai_formatted_messages - - def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage: - """ - Connects the streaming chunks into a single ChatMessage. - """ - complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) - complete_response.metadata.update( - { - "model": chunk.model, - "index": 0, - "finish_reason": chunk.choices[0].finish_reason, - "usage": {}, # we don't have usage data for streaming responses - } - ) - return complete_response - - def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage: - """ - Converts the response from the OpenAI API to a ChatMessage. - :param completion: The completion returned by the OpenAI API. - :param choice: The choice returned by the OpenAI API. - :return: The ChatMessage. 
- """ - message: OpenAIObject = choice.message - content = dict(message.function_call) if choice.finish_reason == "function_call" else message.content - chat_message = ChatMessage.from_assistant(content) - chat_message.metadata.update( - { - "model": completion.model, - "index": choice.index, - "finish_reason": choice.finish_reason, - "usage": dict(completion.usage.items()), - } - ) - return chat_message - - def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk: - """ - Converts the response from the OpenAI API to a StreamingChunk. - :param chunk: The chunk returned by the OpenAI API. - :param choice: The choice returned by the OpenAI API. - :return: The StreamingChunk. - """ - has_content = bool(hasattr(choice.delta, "content") and choice.delta.content) - if has_content: - content = choice.delta.content - elif hasattr(choice.delta, "function_call"): - content = str(choice.delta.function_call) - else: - content = "" - chunk_message = StreamingChunk(content) - chunk_message.metadata.update( - {"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason} - ) - return chunk_message - - def _check_finish_reason(self, message: ChatMessage) -> None: - """ - Check the `finish_reason` returned with the OpenAI completions. - If the `finish_reason` is `length`, log a warning to the user. - :param message: The message returned by the LLM. - """ - if message.metadata["finish_reason"] == "length": - logger.warning( - "The completion for index %s has been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions.", - message.metadata["index"], - ) - if message.metadata["finish_reason"] == "content_filter": - logger.warning( - "The completion for index %s has been truncated due to the content filter.", message.metadata["index"] - ) diff --git a/haystack/preview/components/generators/utils.py b/haystack/preview/components/generators/utils.py deleted file mode 100644 index 397009e4e0..0000000000 --- a/haystack/preview/components/generators/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import inspect -import sys -from typing import Optional, Callable - -from haystack.preview import DeserializationError -from haystack.preview.dataclasses import StreamingChunk - - -def default_streaming_callback(chunk: StreamingChunk) -> None: - """ - Default callback function for streaming responses. - Prints the tokens of the first completion to stdout as soon as they are received - """ - print(chunk.content, flush=True, end="") - - -def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str: - """ - Serializes the streaming callback handler. - :param streaming_callback: The streaming callback handler function - :return: The full path of the streaming callback handler function - """ - module = inspect.getmodule(streaming_callback) - - # Get the full package path of the function - if module is not None: - full_path = f"{module.__name__}.{streaming_callback.__name__}" - else: - full_path = streaming_callback.__name__ - return full_path - - -def deserialize_callback_handler(callback_name: str) -> Optional[Callable[[StreamingChunk], None]]: - """ - Deserializes the streaming callback handler. 
- :param callback_name: The full path of the streaming callback handler function - :return: The streaming callback handler function - :raises DeserializationError: If the streaming callback handler function cannot be found - """ - parts = callback_name.split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}") - streaming_callback = getattr(module, function_name, None) - if not streaming_callback: - raise DeserializationError(f"Could not locate the streaming callback: {function_name}") - return streaming_callback diff --git a/haystack/preview/components/preprocessors/__init__.py b/haystack/preview/components/preprocessors/__init__.py deleted file mode 100644 index 2045621e22..0000000000 --- a/haystack/preview/components/preprocessors/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.preprocessors.document_cleaner import DocumentCleaner -from haystack.preview.components.preprocessors.document_splitter import DocumentSplitter - -__all__ = ["DocumentSplitter", "DocumentCleaner"] diff --git a/haystack/preview/components/preprocessors/document_cleaner.py b/haystack/preview/components/preprocessors/document_cleaner.py deleted file mode 100644 index 370173baca..0000000000 --- a/haystack/preview/components/preprocessors/document_cleaner.py +++ /dev/null @@ -1,229 +0,0 @@ -import logging -import re -from copy import deepcopy -from functools import partial, reduce -from itertools import chain -from typing import Generator, List, Optional, Set - -from haystack.preview import Document, component - -logger = logging.getLogger(__name__) - - -@component -class DocumentCleaner: - """ - Makes text documents more readable by removing extra whitespaces, empty lines, specified substrings, regexes, page headers and footers (in this order). - This is useful for preparing the documents for further processing by LLMs. - - Example usage in an indexing pipeline: - - ```python - document_store = InMemoryDocumentStore() - p = Pipeline() - p.add_component(instance=TextFileToDocument(), name="text_file_converter") - p.add_component(instance=DocumentCleaner(), name="cleaner") - p.add_component(instance=TextDocumentSplitter(split_by="sentence", split_length=1), name="splitter") - p.add_component(instance=DocumentWriter(document_store=document_store), name="writer") - p.connect("text_file_converter.documents", "cleaner.documents") - p.connect("cleaner.documents", "splitter.documents") - p.connect("splitter.documents", "writer.documents") - ``` - """ - - def __init__( - self, - remove_empty_lines: bool = True, - remove_extra_whitespaces: bool = True, - remove_repeated_substrings: bool = False, - remove_substrings: Optional[List[str]] = None, - remove_regex: Optional[str] = None, - ): - """ - :param remove_empty_lines: Whether to remove empty lines. - :param remove_extra_whitespaces: Whether to remove extra whitespaces. - :param remove_repeated_substrings: Whether to remove repeated substrings (headers/footers) from pages. - Pages in the text need to be separated by form feed character "\f", - which is supported by TextFileToDocument and AzureOCRDocumentConverter. - :param remove_substrings: List of substrings to remove from the text. - :param remove_regex: Regex to match and replace substrings by "". 
- """ - - self.remove_empty_lines = remove_empty_lines - self.remove_extra_whitespaces = remove_extra_whitespaces - self.remove_repeated_substrings = remove_repeated_substrings - self.remove_substrings = remove_substrings - self.remove_regex = remove_regex - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]): - """ - Run the DocumentCleaner on the given list of documents. - The documents' metadata remain unchanged. - """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): - raise TypeError("DocumentCleaner expects a List of Documents as input.") - - cleaned_docs = [] - for doc in documents: - if doc.content is None: - logger.warning( - "DocumentCleaner only cleans text documents but document.content for document ID %s is None.", - doc.id, - ) - cleaned_docs.append(doc) - continue - text = doc.content - - if self.remove_extra_whitespaces: - text = self._remove_extra_whitespaces(text) - if self.remove_empty_lines: - text = self._remove_empty_lines(text) - if self.remove_substrings: - text = self._remove_substrings(text, self.remove_substrings) - if self.remove_regex: - text = self._remove_regex(text, self.remove_regex) - if self.remove_repeated_substrings: - text = self._remove_repeated_substrings(text) - - cleaned_docs.append(Document(content=text, meta=deepcopy(doc.meta))) - - return {"documents": cleaned_docs} - - def _remove_empty_lines(self, text: str) -> str: - """ - Remove empty lines and lines that contain nothing but whitespaces from text. - :param text: Text to clean. - :param return: The text without empty lines. - """ - lines = text.split("\n") - non_empty_lines = filter(lambda line: line.strip() != "", lines) - return "\n".join(non_empty_lines) - - def _remove_extra_whitespaces(self, text: str) -> str: - """ - Remove extra whitespaces from text. - :param text: Text to clean. - :param return: The text without extra whitespaces. - """ - return re.sub(r"\s\s+", " ", text).strip() - - def _remove_regex(self, text: str, regex: str) -> str: - """ - Remove substrings that match the specified regex from the text. - :param text: Text to clean. - :param regex: Regex to match and replace substrings by "". - :param return: The text without any substrings that match the regex. - """ - return re.sub(regex, "", text).strip() - - def _remove_substrings(self, text: str, substrings: List[str]) -> str: - """ - Remove all specified substrings from the text. - :param text: Text to clean. - :param substrings: Substrings to remove. - :return: The text without the specified substrings. - """ - for substring in substrings: - text = text.replace(substring, "") - return text - - def _remove_repeated_substrings(self, text: str) -> str: - """ - Remove any substrings from the text that occur repeatedly on every page. For example headers or footers. - Pages in the text need to be separated by form feed character "\f". - :param text: Text to clean. - :return: The text without the repeated substrings. - """ - return self._find_and_remove_header_footer( - text, n_chars=300, n_first_pages_to_ignore=1, n_last_pages_to_ignore=1 - ) - - def _find_and_remove_header_footer( - self, text: str, n_chars: int, n_first_pages_to_ignore: int, n_last_pages_to_ignore: int - ) -> str: - """ - Heuristic to find footers and headers across different pages by searching for the longest common string. - Pages in the text need to be separated by form feed character "\f". 
- For headers, we only search in the first n_chars characters (for footer: last n_chars). - Note: This heuristic uses exact matches and therefore works well for footers like "Copyright 2019 by XXX", - but won't detect "Page 3 of 4" or similar. - - :param n_chars: The number of first/last characters where the header/footer shall be searched in. - :param n_first_pages_to_ignore: The number of first pages to ignore (e.g. TOCs often don't contain footer/header). - :param n_last_pages_to_ignore: The number of last pages to ignore. - :return: The text without the found headers and footers. - """ - - pages = text.split("\f") - - # header - start_of_pages = [p[:n_chars] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]] - found_header = self._find_longest_common_ngram(start_of_pages) - if found_header: - pages = [page.replace(found_header, "") for page in pages] - - # footer - end_of_pages = [p[-n_chars:] for p in pages[n_first_pages_to_ignore:-n_last_pages_to_ignore]] - found_footer = self._find_longest_common_ngram(end_of_pages) - if found_footer: - pages = [page.replace(found_footer, "") for page in pages] - - logger.debug("Removed header '%s' and footer '%s' in document", found_header, found_footer) - text = "\f".join(pages) - return text - - def _ngram(self, seq: str, n: int) -> Generator[str, None, None]: - """ - Return all ngrams of length n from a text sequence. Each ngram consists of n words split by whitespace. - :param seq: The sequence to generate ngrams from. - :param n: The length of the ngrams to generate. - :return: A Generator generating all ngrams of length n from the given sequence. - """ - - # In order to maintain the original whitespace, but still consider \n and \t for n-gram tokenization, - # we add a space here and remove it after creation of the ngrams again (see below) - seq = seq.replace("\n", " \n") - seq = seq.replace("\t", " \t") - - words = seq.split(" ") - ngrams = ( - " ".join(words[i : i + n]).replace(" \n", "\n").replace(" \t", "\t") for i in range(0, len(words) - n + 1) - ) - - return ngrams - - def _allngram(self, seq: str, min_ngram: int, max_ngram: int) -> Set[str]: - """ - Generates all possible ngrams from a given sequence of text. - Considering all ngram lengths between the minimum and maximum length. - - :param seq: The sequence to generate ngrams from. - :param min_ngram: The minimum length of ngram to consider. - :param max_ngram: The maximum length of ngram to consider. - :return: A set of all ngrams from the given sequence. - """ - lengths = range(min_ngram, max_ngram) if max_ngram else range(min_ngram, len(seq)) - ngrams = map(partial(self._ngram, seq), lengths) - res = set(chain.from_iterable(ngrams)) - return res - - def _find_longest_common_ngram(self, sequences: List[str], min_ngram: int = 3, max_ngram: int = 30) -> str: - """ - Find the longest common ngram across a list of text sequences (e.g. start of pages). - Considering all ngram lengths between the minimum and maximum length. Helpful for finding footers, headers etc. - Empty sequences are ignored. - - :param sequences: The list of strings that shall be searched for common n_grams. - :param max_ngram: The maximum length of ngram to consider. - :param min_ngram: The minimum length of ngram to consider. - :return: The longest ngram that all sequences have in common. 
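A sketch of the header-removal heuristic in action; the page texts and the repeated "ACME Corp Confidential" line are invented for illustration. Four pages are used because the first and last page are ignored when searching for the common n-gram:

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentCleaner

# Pages must be separated by the form feed character "\f".
pages = [
    "ACME Corp Confidential\nFirst page body.",
    "ACME Corp Confidential\nSecond page body.",
    "ACME Corp Confidential\nThird page body.",
    "ACME Corp Confidential\nFourth page body.",
]
doc = Document(content="\f".join(pages))

cleaner = DocumentCleaner(remove_repeated_substrings=True)
cleaned = cleaner.run(documents=[doc])["documents"][0]
print(cleaned.content)  # the repeated header line has been stripped from every page
```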
- """ - sequences = [s for s in sequences if s] # filter empty sequences - if not sequences: - return "" - seqs_ngrams = map(partial(self._allngram, min_ngram=min_ngram, max_ngram=max_ngram), sequences) - intersection = reduce(set.intersection, seqs_ngrams) - - longest = max(intersection, key=len, default="") - return longest if longest.strip() else "" diff --git a/haystack/preview/components/preprocessors/document_splitter.py b/haystack/preview/components/preprocessors/document_splitter.py deleted file mode 100644 index ecb8a3f11f..0000000000 --- a/haystack/preview/components/preprocessors/document_splitter.py +++ /dev/null @@ -1,91 +0,0 @@ -from copy import deepcopy -from typing import List, Literal - -from more_itertools import windowed - -from haystack.preview import component, Document - - -@component -class DocumentSplitter: - """ - Splits a list of text documents into a list of text documents with shorter texts. - This is useful for splitting documents with long texts that otherwise would not fit into the maximum text length of language models. - """ - - def __init__( - self, split_by: Literal["word", "sentence", "passage"] = "word", split_length: int = 200, split_overlap: int = 0 - ): - """ - :param split_by: The unit by which the document should be split. Choose from "word" for splitting by " ", - "sentence" for splitting by ".", or "passage" for splitting by "\n\n". - :param split_length: The maximum number of units in each split. - :param split_overlap: The number of units that each split should overlap. - """ - - self.split_by = split_by - if split_by not in ["word", "sentence", "passage"]: - raise ValueError("split_by must be one of 'word', 'sentence' or 'passage'.") - if split_length <= 0: - raise ValueError("split_length must be greater than 0.") - self.split_length = split_length - if split_overlap < 0: - raise ValueError("split_overlap must be greater than or equal to 0.") - self.split_overlap = split_overlap - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document]): - """ - Splits the documents by split_by after split_length units with an overlap of split_overlap units. - Returns a list of documents with the split texts. - A metadata field "source_id" is added to each document to keep track of the original document that was split. - Other metadata are copied from the original document. - :param documents: The documents to split. - :return: A list of documents with the split texts. - """ - - if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): - raise TypeError("DocumentSplitter expects a List of Documents as input.") - - split_docs = [] - for doc in documents: - if doc.content is None: - raise ValueError( - f"DocumentSplitter only works with text documents but document.content for document ID {doc.id} is None." - ) - units = self._split_into_units(doc.content, self.split_by) - text_splits = self._concatenate_units(units, self.split_length, self.split_overlap) - metadata = deepcopy(doc.meta) - metadata["source_id"] = doc.id - split_docs += [Document(content=txt, meta=metadata) for txt in text_splits] - return {"documents": split_docs} - - def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage"]) -> List[str]: - if split_by == "passage": - split_at = "\n\n" - elif split_by == "sentence": - split_at = "." 
- elif split_by == "word": - split_at = " " - else: - raise NotImplementedError( - "DocumentSplitter only supports 'passage', 'sentence' or 'word' split_by options." - ) - units = text.split(split_at) - # Add the delimiter back to all units except the last one - for i in range(len(units) - 1): - units[i] += split_at - return units - - def _concatenate_units(self, elements: List[str], split_length: int, split_overlap: int) -> List[str]: - """ - Concatenates the elements into parts of split_length units. - """ - text_splits = [] - segments = windowed(elements, n=split_length, step=split_length - split_overlap) - for seg in segments: - current_units = [unit for unit in seg if unit is not None] - txt = "".join(current_units) - if len(txt) > 0: - text_splits.append(txt) - return text_splits diff --git a/haystack/preview/components/rankers/__init__.py b/haystack/preview/components/rankers/__init__.py deleted file mode 100644 index f188f01225..0000000000 --- a/haystack/preview/components/rankers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.rankers.meta_field import MetaFieldRanker -from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker - -__all__ = ["MetaFieldRanker", "TransformersSimilarityRanker"] diff --git a/haystack/preview/components/rankers/meta_field.py b/haystack/preview/components/rankers/meta_field.py deleted file mode 100644 index 7c7e4a73cc..0000000000 --- a/haystack/preview/components/rankers/meta_field.py +++ /dev/null @@ -1,180 +0,0 @@ -import logging -import warnings -from collections import defaultdict -from typing import List, Dict, Any, Optional, Literal - -from haystack.preview import ComponentError, Document, component, default_to_dict - -logger = logging.getLogger(__name__) - - -@component -class MetaFieldRanker: - """ - Ranks Documents based on the value of their specific metadata field. The ranking is done in a descending order. - - Usage example: - ``` - from haystack.preview import Document - from haystack.preview.components.rankers import MetaFieldRanker - - ranker = MetaFieldRanker(metadata_field="rating") - docs = [ - Document(text="Paris", metadata={"rating": 1.3}), - Document(text="Berlin", metadata={"rating": 0.7}), - Document(text="Barcelona", metadata={"rating": 2.1}), - ] - - output = ranker.run(documents=docs) - docs = output["documents"] - assert docs[0].text == "Barcelona" - """ - - def __init__( - self, - metadata_field: str, - weight: float = 1.0, - top_k: Optional[int] = None, - ranking_mode: Literal["reciprocal_rank_fusion", "linear_score"] = "reciprocal_rank_fusion", - ): - """ - Creates an instance of MetaFieldRanker. - - :param metadata_field: The name of the metadata field to rank by. - :param weight: In range [0,1]. - 0 disables ranking by a metadata field. - 0.5 content and metadata fields have the same impact for the ranking. - 1 means ranking by a metadata field only. The highest value comes first. - :param top_k: The maximum number of Documents you want the Ranker to return per query. - :param ranking_mode: The mode used to combine the Retriever's and Ranker's scores. - Possible values are 'reciprocal_rank_fusion' (default) and 'linear_score'. - Use the 'score' mode only with Retrievers or Rankers that return a score in range [0,1]. 
- """ - - self.metadata_field = metadata_field - self.weight = weight - self.top_k = top_k - self.ranking_mode = ranking_mode - - if self.weight < 0 or self.weight > 1: - raise ValueError( - """ - Parameter must be in range [0,1] but is currently set to '{}'.\n - '0' disables sorting by a metadata field, '0.5' assigns equal weight to the previous relevance scores and the metadata field, and '1' ranks by the metadata field only.\n - Change the parameter to a value in range 0 to 1 when initializing the MetaFieldRanker. - """.format( - self.weight - ) - ) - - if self.ranking_mode not in ["reciprocal_rank_fusion", "linear_score"]: - raise ValueError( - """ - The value of parameter must be 'reciprocal_rank_fusion' or 'linear_score', but is currently set to '{}'. \n - Change the value to 'reciprocal_rank_fusion' or 'linear_score' when initializing the MetaFieldRanker. - """.format( - self.ranking_mode - ) - ) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize object to a dictionary. - """ - return default_to_dict( - self, - metadata_field=self.metadata_field, - weight=self.weight, - top_k=self.top_k, - ranking_mode=self.ranking_mode, - ) - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document], top_k: Optional[int] = None): - """ - Use this method to rank a list of Documents based on the selected metadata field by: - 1. Sorting the Documents by the metadata field in descending order. - 2. Merging the scores from the metadata field with the scores from the previous component according to the strategy and weight provided. - 3. Returning the top-k documents. - - :param documents: Documents to be ranked. - :param top_k: (optional) The number of Documents you want the Ranker to return. If not provided, the Ranker returns all Documents it received. - """ - if not documents: - return {"documents": []} - - if top_k is None: - top_k = self.top_k - elif top_k <= 0: - raise ValueError(f"top_k must be > 0, but got {top_k}") - - try: - sorted_by_metadata = sorted(documents, key=lambda doc: doc.meta[self.metadata_field], reverse=True) - except KeyError: - raise ComponentError( - """ - The parameter is currently set to '{}' but the Documents {} don't have this metadata key.\n - Double-check the names of the metadata fields in your documents \n - and set to the name of the field that contains the metadata you want to use for ranking. - """.format( - self.metadata_field, ",".join([doc.id for doc in documents if self.metadata_field not in doc.meta]) - ) - ) - - if self.weight > 0: - sorted_documents = self._merge_scores(documents, sorted_by_metadata) - return {"documents": sorted_documents[:top_k]} - else: - return {"documents": sorted_by_metadata[:top_k]} - - def _merge_scores(self, documents: List[Document], sorted_documents: List[Document]) -> List[Document]: - """ - Merge scores for Documents sorted both by their content and by their metadata field. 
- """ - scores_map: Dict = defaultdict(int) - - if self.ranking_mode == "reciprocal_rank_fusion": - for i, (doc, sorted_doc) in enumerate(zip(documents, sorted_documents)): - scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight) - scores_map[sorted_doc.id] += self._calculate_rrf(rank=i) * self.weight - elif self.ranking_mode == "linear_score": - for i, (doc, sorted_doc) in enumerate(zip(documents, sorted_documents)): - score = float(0) - if doc.score is None: - warnings.warn("The score wasn't provided; defaulting to 0.") - elif doc.score < 0 or doc.score > 1: - warnings.warn( - "The score {} for Document {} is outside the [0,1] range; defaulting to 0".format( - doc.score, doc.id - ) - ) - else: - score = doc.score - - scores_map[doc.id] += score * (1 - self.weight) - scores_map[sorted_doc.id] += self._calc_linear_score(rank=i, amount=len(sorted_documents)) * self.weight - - for doc in documents: - doc.score = scores_map[doc.id] - - new_sorted_documents = sorted(documents, key=lambda doc: doc.score if doc.score else -1, reverse=True) - return new_sorted_documents - - @staticmethod - def _calculate_rrf(rank: int, k: int = 61) -> float: - """ - Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper, - plus 1 as python lists are 0-based and the [paper](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) used 1-based ranking). - """ - return 1 / (k + rank) - - @staticmethod - def _calc_linear_score(rank: int, amount: int) -> float: - """ - Calculate the metadata field score as a linear score between the greatest and the lowest score in the list. - This linear scaling is useful for: - - Reducing the effect of outliers - - Creating scores that are meaningfully distributed in the range [0,1], - similar to scores coming from a Retriever or Ranker. - """ - return (amount - rank) / amount diff --git a/haystack/preview/components/rankers/transformers_similarity.py b/haystack/preview/components/rankers/transformers_similarity.py deleted file mode 100644 index 0c4176a7a1..0000000000 --- a/haystack/preview/components/rankers/transformers_similarity.py +++ /dev/null @@ -1,134 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Union, Dict, Any, Optional - -from haystack.preview import ComponentError, Document, component, default_to_dict -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - - -with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: - import torch - from transformers import AutoModelForSequenceClassification, AutoTokenizer - - -@component -class TransformersSimilarityRanker: - """ - Ranks Documents based on their similarity to the query. - It uses a pre-trained cross-encoder model (from the Hugging Face Hub) to embed the query and the Documents. 
- - Usage example: - ``` - from haystack.preview import Document - from haystack.preview.components.rankers import TransformersSimilarityRanker - - ranker = TransformersSimilarityRanker() - docs = [Document(content="Paris"), Document(content="Berlin")] - query = "City in Germany" - output = ranker.run(query=query, documents=docs) - docs = output["documents"] - assert len(docs) == 2 - assert docs[0].content == "Berlin" - ``` - """ - - def __init__( - self, - model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", - device: str = "cpu", - token: Union[bool, str, None] = None, - top_k: int = 10, - ): - """ - Creates an instance of TransformersSimilarityRanker. - - :param model_name_or_path: The name or path of a pre-trained cross-encoder model - from the Hugging Face Hub. - :param device: The torch device (for example, cuda:0, cpu, mps) to which you want to limit model inference. - :param token: The API token used to download private models from Hugging Face. - If this parameter is set to `True`, the token generated when running - `transformers-cli login` (stored in ~/.huggingface) is used. - :param top_k: The maximum number of Documents to return per query. - """ - torch_and_transformers_import.check() - - self.model_name_or_path = model_name_or_path - if top_k <= 0: - raise ValueError(f"top_k must be > 0, but got {top_k}") - self.top_k = top_k - self.device = device - self.token = token - self.model = None - self.tokenizer = None - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": str(self.model_name_or_path)} - - def warm_up(self): - """ - Warm up the model and tokenizer used for scoring the Documents. - """ - if self.model_name_or_path and not self.model: - self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name_or_path, token=self.token) - self.model = self.model.to(self.device) - self.model.eval() - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token) - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, - device=self.device, - model_name_or_path=self.model_name_or_path, - token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens - top_k=self.top_k, - ) - - @component.output_types(documents=List[Document]) - def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): - """ - Returns a list of Documents ranked by their similarity to the given query. - - :param query: Query string. - :param documents: List of Documents. - :param top_k: The maximum number of Documents you want the Ranker to return. - :return: List of Documents sorted by their similarity to the query with the most similar Documents appearing first. - """ - if not documents: - return {"documents": []} - - if top_k is None: - top_k = self.top_k - - elif top_k <= 0: - raise ValueError(f"top_k must be > 0, but got {top_k}") - - # If a model path is provided but the model isn't loaded - if self.model_name_or_path and not self.model: - raise ComponentError( - f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'." 
- ) - - query_doc_pairs = [[query, doc.content] for doc in documents] - features = self.tokenizer( - query_doc_pairs, padding=True, truncation=True, return_tensors="pt" - ).to( # type: ignore - self.device - ) - with torch.inference_mode(): - similarity_scores = self.model(**features).logits.squeeze() # type: ignore - - _, sorted_indices = torch.sort(similarity_scores, descending=True) - ranked_docs = [] - for sorted_index_tensor in sorted_indices: - i = sorted_index_tensor.item() - documents[i].score = similarity_scores[i].item() - ranked_docs.append(documents[i]) - return {"documents": ranked_docs[:top_k]} diff --git a/haystack/preview/components/readers/__init__.py b/haystack/preview/components/readers/__init__.py deleted file mode 100644 index e48da38979..0000000000 --- a/haystack/preview/components/readers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.readers.extractive import ExtractiveReader - -__all__ = ["ExtractiveReader"] diff --git a/haystack/preview/components/readers/extractive.py b/haystack/preview/components/readers/extractive.py deleted file mode 100644 index 5c763eef94..0000000000 --- a/haystack/preview/components/readers/extractive.py +++ /dev/null @@ -1,421 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union -import math -import warnings -import logging -import os - -from haystack.preview import component, default_to_dict, ComponentError, Document, ExtractedAnswer -from haystack.preview.lazy_imports import LazyImport - -with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: - from transformers import AutoModelForQuestionAnswering, AutoTokenizer - from tokenizers import Encoding - import torch - - -logger = logging.getLogger(__name__) - - -@component -class ExtractiveReader: - """ - A component that locates and extract answers to a given query from Documents. It's used for performing extractive - QA. The Reader assigns a probability score to every possible answer span independently of other answer spans. - This fixes a common issue of other implementations which make comparisons across documents harder by normalizing - each document's answers independently. - - Example usage: - ```python - p = Pipeline() - p.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") - p.add_component(instance=ExtractiveReader(), name="reader") - p.connect("retriever", "reader") - question = "Who lives in Berlin?" - p.run({"retriever": {"query": question}, "reader": {"query": question}}) - ``` - """ - - def __init__( - self, - model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled", - device: Optional[str] = None, - token: Union[bool, str, None] = None, - top_k: int = 20, - confidence_threshold: Optional[float] = None, - max_seq_length: int = 384, - stride: int = 128, - max_batch_size: Optional[int] = None, - answers_per_seq: Optional[int] = None, - no_answer: bool = True, - calibration_factor: float = 0.1, - model_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Creates an ExtractiveReader - :param model_name_or_path: A Hugging Face transformers question answering model. - Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub. - Default: `'deepset/roberta-base-squad2-distilled'` - :param device: Pytorch device string. Uses GPU by default, if available. - :param token: The API token used to download private models from Hugging Face. 
- If this parameter is set to `True`, then the token generated when running - `transformers-cli login` (stored in ~/.huggingface) is used. - :param top_k: Number of answers to return per query. - It is required even if confidence_threshold is set. Defaults to 20. - An additional answer with no text is returned if no_answer is set to True (default). - :param confidence_threshold: Returns only answers with the probability score above this threshold. - :param max_seq_length: Maximum number of tokens. - If a sequence exceeds it, the sequence is split. - Default: 384 - :param stride: Number of tokens that overlap when sequence is split because it exceeds max_seq_length. - Default: 128 - :param max_batch_size: Maximum number of samples that are fed through the model at the same time. - :param answers_per_seq: Number of answer candidates to consider per sequence. - This is relevant when a Document was split into multiple sequences because of max_seq_length. - :param no_answer: Whether to return no answer scores. - :param calibration_factor: Factor used for calibrating probability scores. - :param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained` - when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass, - see the model's documentation. - """ - torch_and_transformers_import.check() - self.model_name_or_path = str(model_name_or_path) - self.model = None - self.device = device - self.token = token - self.max_seq_length = max_seq_length - self.top_k = top_k - self.confidence_threshold = confidence_threshold - self.stride = stride - self.max_batch_size = max_batch_size - self.answers_per_seq = answers_per_seq - self.no_answer = no_answer - self.calibration_factor = calibration_factor - self.model_kwargs = model_kwargs or {} - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name_or_path} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, - model_name_or_path=self.model_name_or_path, - device=self.device, - token=self.token if not isinstance(self.token, str) else None, - max_seq_length=self.max_seq_length, - top_k=self.top_k, - confidence_threshold=self.confidence_threshold, - stride=self.stride, - max_batch_size=self.max_batch_size, - answers_per_seq=self.answers_per_seq, - no_answer=self.no_answer, - calibration_factor=self.calibration_factor, - model_kwargs=self.model_kwargs, - ) - - def warm_up(self): - """ - Loads model and tokenizer - """ - if self.model is None: - if torch.cuda.is_available(): - self.device = self.device or "cuda:0" - elif ( - hasattr(torch.backends, "mps") - and torch.backends.mps.is_available() - and os.getenv("HAYSTACK_MPS_ENABLED", "true") != "false" - ): - self.device = self.device or "mps:0" - else: - self.device = self.device or "cpu:0" - - self.model = AutoModelForQuestionAnswering.from_pretrained( - self.model_name_or_path, token=self.token, **self.model_kwargs - ).to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token) - - def _flatten_documents( - self, queries: List[str], documents: List[List[Document]] - ) -> Tuple[List[str], List[Document], List[int]]: - """ - Flattens queries and Documents so all query-document pairs are arranged along one batch axis. 
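# A small, self-contained sketch of the flattening described above, using plain strings in
# place of Document objects: one query with two documents and one query with a single
# document become three parallel entries plus a query-id mapping.
queries = ["q1", "q2"]
documents = [["d1", "d2"], ["d3"]]
flattened_queries = [query for documents_, query in zip(documents, queries) for _ in documents_]
flattened_documents = [document for documents_ in documents for document in documents_]
query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
assert flattened_queries == ["q1", "q1", "q2"]
assert flattened_documents == ["d1", "d2", "d3"]
assert query_ids == [0, 0, 1]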
- """ - flattened_queries = [query for documents_, query in zip(documents, queries) for _ in documents_] - flattened_documents = [document for documents_ in documents for document in documents_] - query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_] - return flattened_queries, flattened_documents, query_ids - - def _preprocess( - self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Encoding], List[int], List[int]]: - """ - Split and tokenize Documents and preserve structures by returning mappings to query and Document IDs. - """ - texts = [] - document_ids = [] - for i, doc in enumerate(documents): - if doc.content is None: - warnings.warn( - f"Document with id {doc.id} was passed to ExtractiveReader. The Document doesn't " - f"contain any text and it will be ignored." - ) - continue - texts.append(doc.content) - document_ids.append(i) - encodings_pt = self.tokenizer( - queries, - [document.content for document in documents], - padding=True, - truncation=True, - max_length=max_seq_length, - return_tensors="pt", - return_overflowing_tokens=True, - stride=stride, - ) - - input_ids = encodings_pt.input_ids.to(self.device) - attention_mask = encodings_pt.attention_mask.to(self.device) - - query_ids = [query_ids[index] for index in encodings_pt.overflow_to_sample_mapping] - document_ids = [document_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping] - - encodings = encodings_pt.encodings - sequence_ids = torch.tensor( - [[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings] - ).to(self.device) - - return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids - - def _postprocess( - self, - start: torch.Tensor, - end: torch.Tensor, - sequence_ids: torch.Tensor, - attention_mask: torch.Tensor, - answers_per_seq: int, - encodings: List[Encoding], - ) -> Tuple[List[List[int]], List[List[int]], torch.Tensor]: - """ - Turn start and end logits into probability scores for each answer span. Unlike most other - implementations, it doesn't normalize the scores to make them easier to compare across different - splits. Returns the top k answer spans. 
- """ - mask = sequence_ids == 1 - mask = torch.logical_and(mask, attention_mask == 1) - start = torch.where(mask, start, -torch.inf) - end = torch.where(mask, end, -torch.inf) - start = start.unsqueeze(-1) - end = end.unsqueeze(-2) - - logits = start + end # shape: (batch_size, seq_length (start), seq_length (end)) - mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device) - mask = torch.triu(mask) # End shouldn't be before start - masked_logits = torch.where(mask, logits, -torch.inf) - probabilities = torch.sigmoid(masked_logits * self.calibration_factor) - - flat_probabilities = probabilities.flatten(-2, -1) # necessary for topk - candidates = torch.topk(flat_probabilities, answers_per_seq) - seq_length = logits.shape[-1] - start_candidates = candidates.indices // seq_length # Recover indices from flattening - end_candidates = candidates.indices % seq_length - start_candidates = start_candidates.cpu() - end_candidates = end_candidates.cpu() - - start_candidates_tokens_to_chars = [ - [encoding.token_to_chars(start) for start in candidates] - for candidates, encoding in zip(start_candidates, encodings) - ] - if missing_start_tokens := [ - (batch, index) - for batch, token_to_chars in enumerate(start_candidates_tokens_to_chars) - for index, pair in enumerate(token_to_chars) - if pair is None - ]: - logger.warning("Some start tokens could not be found in the context: %s", missing_start_tokens) - start_candidates_char_indices = [ - [token_to_chars[0] if token_to_chars else None for token_to_chars in candidates] - for candidates in start_candidates_tokens_to_chars - ] - - end_candidates_tokens_to_chars = [ - [encoding.token_to_chars(end) for end in candidates] - for candidates, encoding in zip(end_candidates, encodings) - ] - if missing_end_tokens := [ - (batch, index) - for batch, token_to_chars in enumerate(end_candidates_tokens_to_chars) - for index, pair in enumerate(token_to_chars) - if pair is None - ]: - logger.warning("Some end tokens could not be found in the context: %s", missing_end_tokens) - end_candidates_char_indices = [ - [token_to_chars[1] if token_to_chars else None for token_to_chars in candidates] - for candidates in end_candidates_tokens_to_chars - ] - - probabilities = candidates.values.cpu() - - return start_candidates_char_indices, end_candidates_char_indices, probabilities - - def _nest_answers( - self, - start: List[List[int]], - end: List[List[int]], - probabilities: torch.Tensor, - flattened_documents: List[Document], - queries: List[str], - answers_per_seq: int, - top_k: Optional[int], - confidence_threshold: Optional[float], - query_ids: List[int], - document_ids: List[int], - no_answer: bool, - ) -> List[List[ExtractedAnswer]]: - """ - Reconstructs the nested structure that existed before flattening. Also computes a no answer probability. - This probability is different from most other implementations because it does not consider the no answer - logit introduced with SQuAD 2. Instead, it just computes the probability that the answer does not exist - in the top k or top p. - """ - flat_answers_without_queries = [] - for document_id, start_candidates_, end_candidates_, probabilities_ in zip( - document_ids, start, end, probabilities - ): - for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_): - doc = flattened_documents[document_id] - # doc.content cannot be None, because those documents are filtered when preprocessing. - # However, mypy doesn't know that. 
- flat_answers_without_queries.append( - { - "data": doc.content[start_:end_], # type: ignore - "document": doc, - "probability": probability.item(), - "start": start_, - "end": end_, - "metadata": {}, - } - ) - i = 0 - nested_answers = [] - for query_id in range(query_ids[-1] + 1): - current_answers = [] - while i < len(flat_answers_without_queries) and query_ids[i // answers_per_seq] == query_id: - answer = flat_answers_without_queries[i] - answer["query"] = queries[query_id] - current_answers.append(ExtractedAnswer(**answer)) - i += 1 - current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True) - current_answers = current_answers[:top_k] - if no_answer: - no_answer_probability = math.prod(1 - answer.probability for answer in current_answers) - answer_ = ExtractedAnswer( - data=None, query=queries[query_id], metadata={}, document=None, probability=no_answer_probability - ) - current_answers.append(answer_) - current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True) - if confidence_threshold is not None: - current_answers = [answer for answer in current_answers if answer.probability >= confidence_threshold] - nested_answers.append(current_answers) - - return nested_answers - - @component.output_types(answers=List[ExtractedAnswer]) - def run( - self, - query: str, - documents: List[Document], - top_k: Optional[int] = None, - confidence_threshold: Optional[float] = None, - max_seq_length: Optional[int] = None, - stride: Optional[int] = None, - max_batch_size: Optional[int] = None, - answers_per_seq: Optional[int] = None, - no_answer: Optional[bool] = None, - ): - """ - Locates and extracts answers from the given Documents using the given query. - - :param query: Query string. - :param documents: List of Documents in which you want to search for an answer to the query. - :param top_k: The maximum number of answers to return. - An additional answer is returned if no_answer is set to True (default). - :param confidence_threshold: - :return: List of ExtractedAnswers sorted by (desc.) answer score. - :param confidence_threshold: Returns only answers with the probability score above this threshold. - :param max_seq_length: Maximum number of tokens. - If a sequence exceeds it, the sequence is split. - Default: 384 - :param stride: Number of tokens that overlap when sequence is split because it exceeds max_seq_length. - Default: 128 - :param max_batch_size: Maximum number of samples that are fed through the model at the same time. - :param answers_per_seq: Number of answer candidates to consider per sequence. - This is relevant when a Document was split into multiple sequences because of max_seq_length. - :param no_answer: Whether to return no answer scores. - """ - queries = [query] # Temporary solution until we have decided what batching should look like in v2 - nested_documents = [documents] - if self.model is None: - raise ComponentError("The component was not warmed up. 
Run 'warm_up()' before calling 'run()'.") - - top_k = top_k or self.top_k - confidence_threshold = confidence_threshold or self.confidence_threshold - max_seq_length = max_seq_length or self.max_seq_length - stride = stride or self.stride - max_batch_size = max_batch_size or self.max_batch_size - answers_per_seq = answers_per_seq or self.answers_per_seq or top_k or 20 - no_answer = no_answer if no_answer is not None else self.no_answer - - flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents) - input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess( - flattened_queries, flattened_documents, max_seq_length, query_ids, stride - ) - - num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1 - batch_size = max_batch_size or input_ids.shape[0] - - start_logits_list = [] - end_logits_list = [] - - for i in range(num_batches): - start_index = i * batch_size - end_index = start_index + batch_size - cur_input_ids = input_ids[start_index:end_index] - cur_attention_mask = attention_mask[start_index:end_index] - - output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) - cur_start_logits = output.start_logits - cur_end_logits = output.end_logits - if num_batches != 1: - cur_start_logits = cur_start_logits.cpu() - cur_end_logits = cur_end_logits.cpu() - start_logits_list.append(cur_start_logits) - end_logits_list.append(cur_end_logits) - - start_logits = torch.cat(start_logits_list) - end_logits = torch.cat(end_logits_list) - - start, end, probabilities = self._postprocess( - start_logits, end_logits, sequence_ids, attention_mask, answers_per_seq, encodings - ) - - answers = self._nest_answers( - start, - end, - probabilities, - flattened_documents, - queries, - answers_per_seq, - top_k, - confidence_threshold, - query_ids, - document_ids, - no_answer, - ) - - return {"answers": answers[0]} # same temporary batching fix as above diff --git a/haystack/preview/components/retrievers/__init__.py b/haystack/preview/components/retrievers/__init__.py deleted file mode 100644 index 92e534d58b..0000000000 --- a/haystack/preview/components/retrievers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever -from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever - -__all__ = ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"] diff --git a/haystack/preview/components/retrievers/in_memory_bm25_retriever.py b/haystack/preview/components/retrievers/in_memory_bm25_retriever.py deleted file mode 100644 index a3f7826123..0000000000 --- a/haystack/preview/components/retrievers/in_memory_bm25_retriever.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Dict, List, Any, Optional - -from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError -from haystack.preview.document_stores import InMemoryDocumentStore, document_store - - -@component -class InMemoryBM25Retriever: - """ - Uses the BM25 algorithm to retrieve documents from the InMemoryDocumentStore. - - Needs to be connected to the InMemoryDocumentStore to run. - """ - - def __init__( - self, - document_store: InMemoryDocumentStore, - filters: Optional[Dict[str, Any]] = None, - top_k: int = 10, - scale_score: bool = False, - ): - """ - Create the InMemoryBM25Retriever component. - - :param document_store: An instance of InMemoryDocumentStore. 
- :param filters: A dictionary with filters to narrow down the search space. Defaults to `None`. - :param top_k: The maximum number of documents to retrieve. Defaults to `10`. - :param scale_score: Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores. - Defaults to `False`. - - :raises ValueError: If the specified `top_k` is not > 0. - """ - if not isinstance(document_store, InMemoryDocumentStore): - raise ValueError("document_store must be an instance of InMemoryDocumentStore") - - self.document_store = document_store - - if top_k <= 0: - raise ValueError(f"top_k must be greater than 0. Currently, the top_k is {top_k}") - - self.filters = filters - self.top_k = top_k - self.scale_score = scale_score - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"document_store": type(self.document_store).__name__} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - docstore = self.document_store.to_dict() - return default_to_dict( - self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InMemoryBM25Retriever": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - if init_params["document_store"]["type"] not in document_store.registry: - raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found") - - docstore_class = document_store.registry[init_params["document_store"]["type"]] - docstore = docstore_class.from_dict(init_params["document_store"]) - data["init_parameters"]["document_store"] = docstore - return default_from_dict(cls, data) - - @component.output_types(documents=List[Document]) - def run( - self, - query: str, - filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - scale_score: Optional[bool] = None, - ): - """ - Run the InMemoryBM25Retriever on the given input data. - - :param query: The query string for the Retriever. - :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to return. - :param scale_score: Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores. - If not specified, the value provided at initialization is used. - :return: The retrieved documents. - - :raises ValueError: If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance. 
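# An illustrative usage sketch for this retriever, assuming the preview package as it exists
# before this removal; the document content and query are made up.
from haystack.preview import Document
from haystack.preview.document_stores import InMemoryDocumentStore
from haystack.preview.components.retrievers import InMemoryBM25Retriever

store = InMemoryDocumentStore()
store.write_documents([Document(content="Paris is the capital of France.")])
retriever = InMemoryBM25Retriever(document_store=store, top_k=3)
result = retriever.run(query="capital of France")
print(result["documents"])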
- """ - if filters is None: - filters = self.filters - if top_k is None: - top_k = self.top_k - if scale_score is None: - scale_score = self.scale_score - - docs = self.document_store.bm25_retrieval(query=query, filters=filters, top_k=top_k, scale_score=scale_score) - return {"documents": docs} diff --git a/haystack/preview/components/retrievers/in_memory_embedding_retriever.py b/haystack/preview/components/retrievers/in_memory_embedding_retriever.py deleted file mode 100644 index dad86fdc58..0000000000 --- a/haystack/preview/components/retrievers/in_memory_embedding_retriever.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Dict, List, Any, Optional - -from haystack.preview import component, Document, default_to_dict, default_from_dict, DeserializationError -from haystack.preview.document_stores import InMemoryDocumentStore, document_store - - -@component -class InMemoryEmbeddingRetriever: - """ - Uses a vector similarity metric to retrieve documents from the InMemoryDocumentStore. - - Needs to be connected to the InMemoryDocumentStore to run. - """ - - def __init__( - self, - document_store: InMemoryDocumentStore, - filters: Optional[Dict[str, Any]] = None, - top_k: int = 10, - scale_score: bool = False, - return_embedding: bool = False, - ): - """ - Create the InMemoryEmbeddingRetriever component. - - :param document_store: An instance of InMemoryDocumentStore. - :param filters: A dictionary with filters to narrow down the search space. Defaults to `None`. - :param top_k: The maximum number of documents to retrieve. Defaults to `10`. - :param scale_score: Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores. - Defaults to `False`. - :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is `False`. - - :raises ValueError: If the specified top_k is not > 0. - """ - if not isinstance(document_store, InMemoryDocumentStore): - raise ValueError("document_store must be an instance of InMemoryDocumentStore") - - self.document_store = document_store - - if top_k <= 0: - raise ValueError(f"top_k must be greater than 0. Currently, top_k is {top_k}") - - self.filters = filters - self.top_k = top_k - self.scale_score = scale_score - self.return_embedding = return_embedding - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"document_store": type(self.document_store).__name__} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - docstore = self.document_store.to_dict() - return default_to_dict( - self, - document_store=docstore, - filters=self.filters, - top_k=self.top_k, - scale_score=self.scale_score, - return_embedding=self.return_embedding, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever": - """ - Deserialize this component from a dictionary. 
- """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - if init_params["document_store"]["type"] not in document_store.registry: - raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found") - - docstore_class = document_store.registry[init_params["document_store"]["type"]] - docstore = docstore_class.from_dict(init_params["document_store"]) - data["init_parameters"]["document_store"] = docstore - return default_from_dict(cls, data) - - @component.output_types(documents=List[Document]) - def run( - self, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - scale_score: Optional[bool] = None, - return_embedding: Optional[bool] = None, - ): - """ - Run the InMemoryEmbeddingRetriever on the given input data. - - :param query_embedding: Embedding of the query. - :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to return. - :param scale_score: Scales the similarity score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores. - If not specified, the value provided at initialization is used. - :param return_embedding: Whether to return the embedding of the retrieved Documents. - :return: The retrieved documents. - - :raises ValueError: If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance. - """ - if filters is None: - filters = self.filters - if top_k is None: - top_k = self.top_k - if scale_score is None: - scale_score = self.scale_score - if return_embedding is None: - return_embedding = self.return_embedding - - docs = self.document_store.embedding_retrieval( - query_embedding=query_embedding, - filters=filters, - top_k=top_k, - scale_score=scale_score, - return_embedding=return_embedding, - ) - - return {"documents": docs} diff --git a/haystack/preview/components/routers/__init__.py b/haystack/preview/components/routers/__init__.py deleted file mode 100644 index 2da95625fc..0000000000 --- a/haystack/preview/components/routers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from haystack.preview.components.routers.document_joiner import DocumentJoiner -from haystack.preview.components.routers.file_type_router import FileTypeRouter -from haystack.preview.components.routers.metadata_router import MetadataRouter -from haystack.preview.components.routers.conditional_router import ConditionalRouter -from haystack.preview.components.routers.text_language_router import TextLanguageRouter - -__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"] diff --git a/haystack/preview/components/routers/conditional_router.py b/haystack/preview/components/routers/conditional_router.py deleted file mode 100644 index af96d96ff3..0000000000 --- a/haystack/preview/components/routers/conditional_router.py +++ /dev/null @@ -1,347 +0,0 @@ -import importlib -import inspect -import logging -import sys -from typing import List, Dict, Any, Set, get_origin - -from jinja2 import meta, Environment, TemplateSyntaxError -from jinja2.nativetypes import NativeEnvironment - -from haystack.preview import component, default_from_dict, default_to_dict, 
DeserializationError - -logger = logging.getLogger(__name__) - - -class NoRouteSelectedException(Exception): - """Exception raised when no route is selected in ConditionalRouter.""" - - -class RouteConditionException(Exception): - """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" - - -def serialize_type(target: Any) -> str: - """ - Serializes a type or an instance to its string representation, including the module name. - - This function handles types, instances of types, and special typing objects. - It assumes that non-typing objects will have a '__name__' attribute and raises - an error if a type cannot be serialized. - - :param target: The object to serialize, can be an instance or a type. - :type target: Any - :return: The string representation of the type. - :raises ValueError: If the type cannot be serialized. - """ - # If the target is a string and contains a dot, treat it as an already serialized type - if isinstance(target, str) and "." in target: - return target - - # Determine if the target is a type or an instance of a typing object - is_type_or_typing = isinstance(target, type) or bool(get_origin(target)) - type_obj = target if is_type_or_typing else type(target) - module = inspect.getmodule(type_obj) - type_obj_repr = repr(type_obj) - - if type_obj_repr.startswith("typing."): - # e.g., typing.List[int] -> List[int], we'll add the module below - type_name = type_obj_repr.split(".", 1)[1] - elif hasattr(type_obj, "__name__"): - type_name = type_obj.__name__ - else: - # If type cannot be serialized, raise an error - raise ValueError(f"Could not serialize type: {type_obj_repr}") - - # Construct the full path with module name if available - if module and hasattr(module, "__name__"): - if module.__name__ == "builtins": - # omit the module name for builtins, it just clutters the output - # e.g. instead of 'builtins.str', we'll just return 'str' - full_path = type_name - else: - full_path = f"{module.__name__}.{type_name}" - else: - full_path = type_name - - return full_path - - -def deserialize_type(type_str: str) -> Any: - """ - Deserializes a type given its full import path as a string, including nested generic types. - - This function will dynamically import the module if it's not already imported - and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'. - - :param type_str: The string representation of the type's full import path. - :return: The deserialized type object. - :raises DeserializationError: If the type cannot be deserialized due to missing module or type. 
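# A quick round-trip sketch for the two helpers defined in this module, based on the
# implementations shown here; the import assumes the module as it exists before this removal.
from typing import Dict, List
from haystack.preview.components.routers.conditional_router import deserialize_type, serialize_type

assert serialize_type(str) == "str"                    # builtins keep the bare name
assert serialize_type(List[int]) == "typing.List[int]"
assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]]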
- """ - - def parse_generic_args(args_str): - args = [] - bracket_count = 0 - current_arg = "" - - for char in args_str: - if char == "[": - bracket_count += 1 - elif char == "]": - bracket_count -= 1 - - if char == "," and bracket_count == 0: - args.append(current_arg.strip()) - current_arg = "" - else: - current_arg += char - - if current_arg: - args.append(current_arg.strip()) - - return args - - if "[" in type_str and type_str.endswith("]"): - # Handle generics - main_type_str, generics_str = type_str.split("[", 1) - generics_str = generics_str[:-1] - - main_type = deserialize_type(main_type_str) - generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str)) - - # Reconstruct - return main_type[generic_args] - - else: - # Handle non-generics - parts = type_str.split(".") - module_name = ".".join(parts[:-1]) or "builtins" - type_name = parts[-1] - - module = sys.modules.get(module_name) - if not module: - try: - module = importlib.import_module(module_name) - except ImportError as e: - raise DeserializationError(f"Could not import the module: {module_name}") from e - - deserialized_type = getattr(module, type_name, None) - if not deserialized_type: - raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}") - - return deserialized_type - - -@component -class ConditionalRouter: - """ - ConditionalRouter in Haystack 2.x pipelines is designed to manage data routing based on specific conditions. - This is achieved by defining a list named 'routes'. Each element in this list is a dictionary representing a - single route. - - A route dictionary comprises four key elements: - - 'condition': A Jinja2 string expression that determines if the route is selected. - - 'output': A Jinja2 expression defining the route's output value. - - 'output_type': The type of the output data (e.g., str, List[int]). - - 'output_name': The name under which the `output` value of the route is published. This name is used to connect - the router to other components in the pipeline. - - Here's an example: - - ```python - from haystack.preview.components.routers import ConditionalRouter - - routes = [ - { - "condition": "{{streams|length > 2}}", - "output": "{{streams}}", - "output_name": "enough_streams", - "output_type": List[int], - }, - { - "condition": "{{streams|length <= 2}}", - "output": "{{streams}}", - "output_name": "insufficient_streams", - "output_type": List[int], - }, - ] - router = ConditionalRouter(routes) - # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3] - kwargs = {"streams": [1, 2, 3], "query": "Haystack"} - result = router.run(**kwargs) - assert result == {"enough_streams": [1, 2, 3]} - ``` - - In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the - stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there - are two or fewer streams. - - In the pipeline setup, the router is connected to other components using the output names. For example, the - 'enough_streams' output might be connected to another component that processes the streams, while the - 'insufficient_streams' output might be connected to a component that fetches more streams, and so on. 
- - Here is a pseudocode example of a pipeline that uses the ConditionalRouter and routes fetched ByteStreams to - different components depending on the number of streams fetched: - - ``` - from typing import List - from haystack import Pipeline - from haystack.preview.dataclasses import ByteStream - from haystack.preview.components.routers import ConditionalRouter - - routes = [ - { - "condition": "{{streams|length > 2}}", - "output": "{{streams}}", - "output_name": "enough_streams", - "output_type": List[ByteStream], - }, - { - "condition": "{{streams|length <= 2}}", - "output": "{{streams}}", - "output_name": "insufficient_streams", - "output_type": List[ByteStream], - }, - ] - - pipe = Pipeline() - pipe.add_component("router", router) - ... - pipe.connect("router.enough_streams", "some_component_a.streams") - pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input") - ... - ``` - """ - - def __init__(self, routes: List[Dict]): - """ - Initializes the ConditionalRouter with a list of routes detailing the conditions for routing. - - :param routes: A list of dictionaries, each defining a route with a boolean condition expression - ('condition'), an output value ('output'), the output type ('output_type') and - ('output_name') that defines the output name for the variable defined in 'output'. - """ - self._validate_routes(routes) - self.routes: List[dict] = routes - - # Create a Jinja native environment to inspect variables in the condition templates - env = NativeEnvironment() - - # Inspect the routes to determine input and output types. - input_types: Set[str] = set() # let's just store the name, type will always be Any - output_types: Dict[str, str] = {} - - for route in routes: - # extract inputs - route_input_names = self._extract_variables(env, [route["output"], route["condition"]]) - input_types.update(route_input_names) - - # extract outputs - output_types.update({route["output_name"]: route["output_type"]}) - - component.set_input_types(self, **{var: Any for var in input_types}) - component.set_output_types(self, **output_types) - - def to_dict(self) -> Dict[str, Any]: - for route in self.routes: - # output_type needs to be serialized to a string - route["output_type"] = serialize_type(route["output_type"]) - - return default_to_dict(self, routes=self.routes) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter": - init_params = data.get("init_parameters", {}) - routes = init_params.get("routes") - for route in routes: - # output_type needs to be deserialized from a string to a type - route["output_type"] = deserialize_type(route["output_type"]) - return default_from_dict(cls, data) - - def run(self, **kwargs): - """ - Executes the routing logic by evaluating the specified boolean condition expressions - for each route in the order they are listed. The method directs the flow - of data to the output specified in the first route, whose expression - evaluates to True. If no route's expression evaluates to True, an exception - is raised. - - :param kwargs: A dictionary containing the pipeline variables, which should - include all variables used in the "condition" templates. - - :return: A dictionary containing the output and the corresponding result, - based on the first route whose expression evaluates to True. - - :raises NoRouteSelectedException: If no route's expression evaluates to True. 
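# A short sketch of how a route condition is evaluated in run() below: Jinja's
# NativeEnvironment renders the expression to a native Python value, whose truthiness
# selects the route. The 'streams' variable is a made-up pipeline input.
from jinja2.nativetypes import NativeEnvironment

env = NativeEnvironment()
condition = "{{streams|length > 2}}"
assert env.from_string(condition).render(streams=[1, 2, 3])
assert not env.from_string(condition).render(streams=[1])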
- """ - # Create a Jinja native environment to evaluate the condition templates as Python expressions - env = NativeEnvironment() - - for route in self.routes: - try: - t = env.from_string(route["condition"]) - if t.render(**kwargs): - # We now evaluate the `output` expression to determine the route output - t_output = env.from_string(route["output"]) - output = t_output.render(**kwargs) - # and return the output as a dictionary under the output_name key - return {route["output_name"]: output} - except Exception as e: - raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e - - raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}") - - def _validate_routes(self, routes: List[Dict]): - """ - Validates a list of routes. - - :param routes: A list of routes. - :type routes: List[Dict] - """ - env = NativeEnvironment() - for route in routes: - try: - keys = set(route.keys()) - except AttributeError: - raise ValueError(f"Route must be a dictionary, got: {route}") - - mandatory_fields = {"condition", "output", "output_type", "output_name"} - has_all_mandatory_fields = mandatory_fields.issubset(keys) - if not has_all_mandatory_fields: - raise ValueError( - f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}" - ) - for field in ["condition", "output"]: - if not self._validate_template(env, route[field]): - raise ValueError(f"Invalid template for field '{field}': {route[field]}") - - def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]: - """ - Extracts all variables from a list of Jinja template strings. - - :param env: A Jinja environment. - :type env: Environment - :param templates: A list of Jinja template strings. - :type templates: List[str] - :return: A set of variable names. - """ - variables = set() - for template in templates: - ast = env.parse(template) - variables.update(meta.find_undeclared_variables(ast)) - return variables - - def _validate_template(self, env: Environment, template_text: str): - """ - Validates a template string by parsing it with Jinja. - - :param env: A Jinja environment. - :type env: Environment - :param template_text: A Jinja template string. - :type template_text: str - :return: True if the template is valid, False otherwise. - """ - try: - env.parse(template_text) - return True - except TemplateSyntaxError: - return False diff --git a/haystack/preview/components/routers/document_joiner.py b/haystack/preview/components/routers/document_joiner.py deleted file mode 100644 index 96b9b989b2..0000000000 --- a/haystack/preview/components/routers/document_joiner.py +++ /dev/null @@ -1,153 +0,0 @@ -import itertools -import logging -from collections import defaultdict -from math import inf -from typing import List, Optional -from canals.component.types import Variadic - -from haystack.preview import component, Document - - -logger = logging.getLogger(__name__) - - -@component -class DocumentJoiner: - """ - A component that joins input lists of Documents from multiple connections and outputs them as one list. - - The component allows multiple join modes: - * concatenate: Combine Documents from multiple components. Discards duplicate Documents. - Documents get their scores from the last component in the pipeline that assigns scores. - This join mode doesn't influence Document scores. - * merge: Merge scores of duplicate Documents coming from multiple components. 
- Optionally, you can assign a weight to the scores and set the top_k limit for this join mode. - You can also use this join mode to rerank retrieved Documents. - * reciprocal_rank_fusion: Combine Documents into a single list based on their ranking received from multiple components. - - Example usage in a hybrid retrieval pipeline: - ```python - document_store = InMemoryDocumentStore() - p = Pipeline() - p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever") - p.add_component( - instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"), - name="text_embedder", - ) - p.add_component(instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever") - p.add_component(instance=DocumentJoiner(), name="joiner") - p.connect("bm25_retriever", "joiner") - p.connect("embedding_retriever", "joiner") - p.connect("text_embedder", "embedding_retriever") - query = "What is the capital of France?" - p.run(data={"bm25_retriever": {"query": query}, - "text_embedder": {"text": query}}) - ``` - """ - - def __init__( - self, - join_mode: str = "concatenate", - weights: Optional[List[float]] = None, - top_k: Optional[int] = None, - sort_by_score: bool = True, - ): - """ - Initialize the DocumentJoiner. - - :param join_mode: Specifies the join mode to use. Available modes: `concatenate` to combine Documents from multiple Retrievers, `merge` to aggregate the scores of - individual Documents, `reciprocal_rank_fusion` to apply rank-based scoring. - :param weights: A component-wise list (the length of the list must be equal to the number of input components) of weights for - adjusting Document scores when using the `merge` join_mode. By default, equal weight is given - to each Retriever score. This param is not compatible with the `concatenate` join_mode. - :param top_k: The maximum number of Documents to be returned as output. By default, returns all Documents. - :param sort_by_score: Whether the output list of Documents should be sorted by Document scores in descending order. - By default, the output is sorted. - Documents without score are handled as if their score was -infinity. - """ - if join_mode not in ["concatenate", "merge", "reciprocal_rank_fusion"]: - raise ValueError(f"DocumentJoiner component does not support '{join_mode}' join_mode.") - self.join_mode = join_mode - self.weights = [float(i) / sum(weights) for i in weights] if weights else None - self.top_k = top_k - self.sort_by_score = sort_by_score - - @component.output_types(documents=List[Document]) - def run(self, documents: Variadic[List[Document]]): - """ - Run the DocumentJoiner. This method joins the input lists of Documents into one output list based on the join_mode specified during initialization. - - :param documents: An arbitrary number of lists of Documents to join. - """ - output_documents = [] - if self.join_mode == "concatenate": - output_documents = self._concatenate(documents) - elif self.join_mode == "merge": - output_documents = self._merge(documents) - elif self.join_mode == "reciprocal_rank_fusion": - output_documents = self._reciprocal_rank_fusion(documents) - - if self.sort_by_score: - output_documents = sorted( - output_documents, key=lambda doc: doc.score if doc.score is not None else -inf, reverse=True - ) - if any(doc.score is None for doc in output_documents): - logger.info( - "Some of the Documents DocumentJoiner got have score=None. 
It was configured to sort Documents by " - "score, so those with score=None were sorted as if they had a score of -infinity." - ) - - if self.top_k: - output_documents = output_documents[: self.top_k] - return {"documents": output_documents} - - def _concatenate(self, document_lists): - """ - Concatenate multiple lists of Documents and return only the Document with the highest score for duplicate Documents. - """ - output = [] - docs_per_id = defaultdict(list) - for doc in itertools.chain.from_iterable(document_lists): - docs_per_id[doc.id].append(doc) - for docs in docs_per_id.values(): - doc_with_best_score = max(docs, key=lambda doc: doc.score if doc.score else -inf) - output.append(doc_with_best_score) - return output - - def _merge(self, document_lists): - """ - Merge multiple lists of Documents and calculate a weighted sum of the scores of duplicate Documents. - """ - scores_map = defaultdict(int) - documents_map = {} - weights = self.weights if self.weights else [1 / len(document_lists)] * len(document_lists) - - for documents, weight in zip(document_lists, weights): - for doc in documents: - scores_map[doc.id] += (doc.score if doc.score else 0) * weight - documents_map[doc.id] = doc - - for doc in documents_map.values(): - doc.score = scores_map[doc.id] - - return documents_map.values() - - def _reciprocal_rank_fusion(self, document_lists): - """ - Merge multiple lists of Documents and assign scores based on reciprocal rank fusion. - The constant k is set to 61 (60 was suggested by the original paper, - plus 1 as python lists are 0-based and the paper used 1-based ranking). - """ - k = 61 - - scores_map = defaultdict(int) - documents_map = {} - for documents in document_lists: - for rank, doc in enumerate(documents): - scores_map[doc.id] += 1 / (k + rank) - documents_map[doc.id] = doc - - for doc in documents_map.values(): - doc.score = scores_map[doc.id] - - return documents_map.values() diff --git a/haystack/preview/components/routers/file_type_router.py b/haystack/preview/components/routers/file_type_router.py deleted file mode 100644 index e644129706..0000000000 --- a/haystack/preview/components/routers/file_type_router.py +++ /dev/null @@ -1,87 +0,0 @@ -import logging -import mimetypes -from collections import defaultdict -from pathlib import Path -from typing import List, Union, Optional, Dict - -from haystack.preview import component -from haystack.preview.dataclasses import ByteStream - -logger = logging.getLogger(__name__) - - -@component -class FileTypeRouter: - """ - FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding - MIME types. For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types - are determined from the provided metadata. - - The set of MIME types to consider is specified during the initialization of the component. - - This component is invaluable when categorizing a large collection of files or data streams by their MIME - types and routing them to different components for further processing. - """ - - def __init__(self, mime_types: List[str]): - """ - Initialize the FileTypeRouter. - - :param mime_types: A list of file mime types to consider when routing - files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]). - """ - if not mime_types: - raise ValueError("The list of mime types cannot be empty.") - - for mime_type in mime_types: - if not self.is_valid_mime_type_format(mime_type): - raise ValueError( - f"Unknown mime type: '{mime_type}'. 
Ensure you passed a list of strings in the 'mime_types' parameter" - ) - - component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types}) - self.mime_types = mime_types - - def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]: - """ - Categorizes the provided data sources by their MIME types. - - :param sources: A list of file paths or byte streams to categorize. - :return: A dictionary where keys are MIME types and values are lists of data sources. - """ - - mime_types = defaultdict(list) - for source in sources: - if isinstance(source, str): - source = Path(source) - - if isinstance(source, Path): - mime_type = self.get_mime_type(source) - elif isinstance(source, ByteStream): - mime_type = source.metadata.get("content_type") - else: - raise ValueError(f"Unsupported data source type: {type(source)}") - - if mime_type in self.mime_types: - mime_types[mime_type].append(source) - else: - mime_types["unclassified"].append(source) - - return mime_types - - def get_mime_type(self, path: Path) -> Optional[str]: - """ - Get the MIME type of the provided file path. - - :param path: The file path to get the MIME type for. - :return: The MIME type of the provided file path, or None if the MIME type cannot be determined. - """ - return mimetypes.guess_type(path.as_posix())[0] - - def is_valid_mime_type_format(self, mime_type: str) -> bool: - """ - Check if the provided MIME type is in valid format - :param mime_type: The MIME type to check. - :return: True if the provided MIME type is a valid MIME type format, False otherwise. - """ - return mime_type in mimetypes.types_map.values() diff --git a/haystack/preview/components/routers/metadata_router.py b/haystack/preview/components/routers/metadata_router.py deleted file mode 100644 index f83b1e5542..0000000000 --- a/haystack/preview/components/routers/metadata_router.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Dict, List - -from haystack.preview import component, Document -from haystack.preview.utils.filters import document_matches_filter, convert - - -@component -class MetadataRouter: - """ - A component that routes documents to different connections based on the content of their fields. - """ - - def __init__(self, rules: Dict[str, Dict]): - """ - Initialize the MetadataRouter. - - :param rules: A dictionary of rules that specify which edge to route a document to based on its metadata. - The keys of the dictionary are the names of the output connections, and the values are dictionaries that - follow the format of filtering expressions in Haystack. 
For example: - ```python - { - "edge_1": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-01-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2023-04-01"}, - ], - }, - "edge_2": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-04-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2023-07-01"}, - ], - }, - "edge_3": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-07-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2023-10-01"}, - ], - }, - "edge_4": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-10-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2024-01-01"}, - ], - }, - } - ``` - """ - self.rules = rules - component.set_output_types(self, unmatched=List[Document], **{edge: List[Document] for edge in rules}) - - def run(self, documents: List[Document]): - """ - Run the MetadataRouter. This method routes the documents to different edges based on their fields content and - the rules specified during initialization. If a document does not match any of the rules, it is routed to - a connection named "unmatched". - - :param documents: A list of documents to route to different edges. - """ - unmatched_documents = [] - output: Dict[str, List[Document]] = {edge: [] for edge in self.rules} - - for document in documents: - cur_document_matched = False - for edge, rule in self.rules.items(): - if "operator" not in rule: - # Must be a legacy filter, convert it - rule = convert(rule) - if document_matches_filter(rule, document): - output[edge].append(document) - cur_document_matched = True - - if not cur_document_matched: - unmatched_documents.append(document) - - output["unmatched"] = unmatched_documents - return output diff --git a/haystack/preview/components/routers/text_language_router.py b/haystack/preview/components/routers/text_language_router.py deleted file mode 100644 index 4e8ffd0167..0000000000 --- a/haystack/preview/components/routers/text_language_router.py +++ /dev/null @@ -1,73 +0,0 @@ -import logging -from typing import List, Dict, Optional - -from haystack.preview import component -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - -with LazyImport("Run 'pip install langdetect'") as langdetect_import: - import langdetect - - -@component -class TextLanguageRouter: - """ - Routes a text input onto one of different output connections depending on its language. - This is useful for routing queries to different models in a pipeline depending on their language. - The set of supported languages can be specified. - For routing Documents based on their language use the related DocumentLanguageClassifier component to first - classify the documents and then the MetaDataRouter to route them. 
- - Example usage in a retrieval pipeline that passes only English language queries to the retriever: - - ```python - document_store = InMemoryDocumentStore() - p = Pipeline() - p.add_component(instance=TextLanguageRouter(), name="text_language_router") - p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") - p.connect("text_language_router.en", "retriever.query") - p.run({"text_language_router": {"text": "What's your query?"}}) - ``` - """ - - def __init__(self, languages: Optional[List[str]] = None): - """ - :param languages: A list of languages in ISO code, each corresponding to a different output connection (see [langdetect` documentation](https://github.com/Mimino666/langdetect#languages)). By default, only ["en"] is supported and texts of any other language are routed to "unmatched". - """ - langdetect_import.check() - if not languages: - languages = ["en"] - self.languages = languages - component.set_output_types(self, unmatched=str, **{language: str for language in languages}) - - def run(self, text: str) -> Dict[str, str]: - """ - Run the TextLanguageRouter. This method routes the text one of different edges based on its language. - If the text does not match any of the languages specified at initialization, it is routed to - a connection named "unmatched". - - :param text: A str to route to one of different edges. - """ - if not isinstance(text, str): - raise TypeError( - "TextLanguageRouter expects a str as input. In case you want to classify a document, please use the DocumentLanguageClassifier and MetaDataRouter." - ) - - output: Dict[str, str] = {} - - detected_language = self.detect_language(text) - if detected_language in self.languages: - output[detected_language] = text - else: - output["unmatched"] = text - - return output - - def detect_language(self, text: str) -> Optional[str]: - try: - language = langdetect.detect(text) - except langdetect.LangDetectException: - logger.warning("Langdetect cannot detect the language of text: %s", text) - language = None - return language diff --git a/haystack/preview/components/samplers/__init__.py b/haystack/preview/components/samplers/__init__.py deleted file mode 100644 index cab0e878e8..0000000000 --- a/haystack/preview/components/samplers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.samplers.top_p import TopPSampler - -__all__ = ["TopPSampler"] diff --git a/haystack/preview/components/samplers/top_p.py b/haystack/preview/components/samplers/top_p.py deleted file mode 100644 index c3740cbbaa..0000000000 --- a/haystack/preview/components/samplers/top_p.py +++ /dev/null @@ -1,127 +0,0 @@ -import logging -from typing import List, Optional - -from haystack.preview import ComponentError, Document, component -from haystack.preview.lazy_imports import LazyImport - -logger = logging.getLogger(__name__) - - -with LazyImport(message="Run 'pip install \"torch>=1.13\"'") as torch_import: - import torch - - -@component -class TopPSampler: - """ - Filters documents using top-p (nucleus) sampling based on their similarity scores' cumulative probability. 
- - Usage example: - - ```python - from haystack.preview import Document - from haystack.preview.components.samplers import TopPSampler - - sampler = TopPSampler(top_p=0.95) - docs = [ - Document(text="Berlin", metadata={"similarity_score": -10.6}), - Document(text="Belgrade", metadata={"similarity_score": -8.9}), - Document(text="Sarajevo", metadata={"similarity_score": -4.6}), - ] - output = sampler.run(documents=docs) - docs = output["documents"] - assert len(docs) == 1 - assert docs[0].content == "Sarajevo" - ``` - """ - - def __init__(self, top_p: float = 1.0, score_field: Optional[str] = None): - """ - Creates an instance of TopPSampler. - - :param top_p: Cumulative probability threshold (usually between 0.9 and 0.99). - :param score_field: Field name in a document's metadata containing the scores. Defaults to the Document score - if not provided. - """ - torch_import.check() - - self.top_p = top_p - self.score_field = score_field - - @component.output_types(documents=List[Document]) - def run(self, documents: List[Document], top_p: Optional[float] = None): - """ - Filter documents based on their similarity scores using top-p sampling. - - :param documents: List of Documents to filter. - :param top_p: Cumulative probability threshold. Defaults to the value set during initialization or 1.0 - if not set. - :return: List of filtered Documents. - """ - if not documents: - return {"documents": []} - - top_p = top_p or self.top_p or 1.0 # default to 1.0 if both are None - - if not 0 <= top_p <= 1: - raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.") - - similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32) - - # Apply softmax normalization to the similarity scores - probs = torch.nn.functional.softmax(similarity_scores, dim=-1) - - # Sort the probabilities and calculate their cumulative sum - sorted_probs, sorted_indices = torch.sort(probs, descending=True) - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) - - # Check if the cumulative probabilities are close to top_p with a 1e-6 tolerance - close_to_top_p = torch.isclose(cumulative_probs, torch.tensor(top_p, device=cumulative_probs.device), atol=1e-6) - - # Combine the close_to_top_p with original condition using logical OR - condition = (cumulative_probs <= top_p) | close_to_top_p - - # Find the indices with cumulative probabilities that exceed top_p - top_p_indices = torch.where(torch.BoolTensor(condition))[0] - - # Map the selected indices back to their original indices - original_indices = sorted_indices[top_p_indices] - selected_docs = [documents[i.item()] for i in original_indices] - - # If low p resulted in no documents being selected, then - # return at least one document - if not selected_docs: - logger.warning( - "Top-p sampling with p=%s resulted in no documents being selected. " - "Returning the document with the highest similarity score.", - top_p, - ) - highest_prob_indices = torch.argsort(probs, descending=True) - selected_docs = [documents[int(highest_prob_indices[0].item())]] - - return {"documents": selected_docs} - - def _collect_scores(self, documents: List[Document]) -> List[float]: - """ - Collect the scores from the documents' metadata. - :param documents: List of Documents. - :return: List of scores. 
- """ - if self.score_field: - missing_scores_docs = [d for d in documents if self.score_field not in d.meta] - if missing_scores_docs: - missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id] - raise ComponentError( - f"Score field '{self.score_field}' not found in metadata of documents " - f"with IDs: {missing_scores_docs_ids}." - f"Make sure that all documents have a score field '{self.score_field}' in their metadata." - ) - return [d.meta[self.score_field] for d in documents] - else: - missing_scores_docs = [d for d in documents if d.score is None] - if missing_scores_docs: - missing_scores_docs_ids = [d.id for d in missing_scores_docs if d.id] - raise ComponentError( - f"Ensure all documents have a valid score value. These docs {missing_scores_docs_ids} don't." - ) - return [d.score for d in documents] # type: ignore ## because Document score is Optional diff --git a/haystack/preview/components/websearch/__init__.py b/haystack/preview/components/websearch/__init__.py deleted file mode 100644 index ce20e77857..0000000000 --- a/haystack/preview/components/websearch/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.components.websearch.serper_dev import SerperDevWebSearch -from haystack.preview.components.websearch.searchapi import SearchApiWebSearch - -__all__ = ["SerperDevWebSearch", "SearchApiWebSearch"] diff --git a/haystack/preview/components/websearch/searchapi.py b/haystack/preview/components/websearch/searchapi.py deleted file mode 100644 index fafe8bd95a..0000000000 --- a/haystack/preview/components/websearch/searchapi.py +++ /dev/null @@ -1,140 +0,0 @@ -import json -import os -import logging -from typing import Dict, List, Optional, Any - -import requests - -from haystack.preview import Document, component, default_to_dict, ComponentError - -logger = logging.getLogger(__name__) - - -SEARCHAPI_BASE_URL = "https://www.searchapi.io/api/v1/search" - - -class SearchApiError(ComponentError): - ... - - -@component -class SearchApiWebSearch: - """ - Search engine using SearchApi API. Given a query, it returns a list of URLs that are the most relevant. - - See the [SearchApi website](https://www.searchapi.io/) for more details. - """ - - def __init__( - self, - api_key: Optional[str] = None, - top_k: Optional[int] = 10, - allowed_domains: Optional[List[str]] = None, - search_params: Optional[Dict[str, Any]] = None, - ): - """ - :param api_key: API key for the SearchApi API. It can be - explicitly provided or automatically read from the - environment variable SEARCHAPI_API_KEY (recommended). - :param top_k: Number of documents to return. - :param allowed_domains: List of domains to limit the search to. - :param search_params: Additional parameters passed to the SearchApi API. - For example, you can set 'num' to 100 to increase the number of search results. - See the [SearchApi website](https://www.searchapi.io/) for more details. - """ - if api_key is None: - try: - api_key = os.environ["SEARCHAPI_API_KEY"] - except KeyError as e: - raise ValueError( - "SearchApiWebSearch expects an API key. " - "Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - self.api_key = api_key - self.top_k = top_k - self.allowed_domains = allowed_domains - self.search_params = search_params or {} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. 
- """ - return default_to_dict( - self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params - ) - - @component.output_types(documents=List[Document], links=List[str]) - def run(self, query: str): - """ - Search the SearchApi API for the given query and return the results as a list of Documents and a list of links. - - :param query: Query string. - """ - query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" - - payload = json.dumps({"q": query_prepend + " " + query, **self.search_params}) - headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "Haystack"} - - try: - response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90) - response.raise_for_status() # Will raise an HTTPError for bad responses - except requests.Timeout: - raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") - - except requests.RequestException as e: - raise SearchApiError(f"An error occurred while querying {self.__class__.__name__}. Error: {e}") from e - - # Request succeeded - json_result = response.json() - - # organic results are the main results from the search engine - organic_results = [] - if "organic_results" in json_result: - for result in json_result["organic_results"]: - organic_results.append( - Document.from_dict({"title": result["title"], "content": result["snippet"], "link": result["link"]}) - ) - - # answer box has a direct answer to the query - answer_box = [] - if "answer_box" in json_result: - answer_box = [ - Document.from_dict( - { - "title": json_result["answer_box"].get("title", ""), - "content": json_result["answer_box"].get("answer", ""), - "link": json_result["answer_box"].get("link", ""), - } - ) - ] - - knowledge_graph = [] - if "knowledge_graph" in json_result: - knowledge_graph = [ - Document.from_dict( - { - "title": json_result["knowledge_graph"].get("title", ""), - "content": json_result["knowledge_graph"].get("description", ""), - } - ) - ] - - related_questions = [] - if "related_questions" in json_result: - for result in json_result["related_questions"]: - related_questions.append( - Document.from_dict( - { - "title": result["question"], - "content": result["answer"] if result.get("answer") else result.get("answer_highlight", ""), - "link": result.get("source", {}).get("link", ""), - } - ) - ) - - documents = answer_box + knowledge_graph + organic_results + related_questions - - links = [result["link"] for result in json_result["organic_results"]] - - logger.debug("SearchApi returned %s documents for the query '%s'", len(documents), query) - return {"documents": documents[: self.top_k], "links": links[: self.top_k]} diff --git a/haystack/preview/components/websearch/serper_dev.py b/haystack/preview/components/websearch/serper_dev.py deleted file mode 100644 index 8b98d3ebc9..0000000000 --- a/haystack/preview/components/websearch/serper_dev.py +++ /dev/null @@ -1,140 +0,0 @@ -import json -import os -import logging -from typing import Dict, List, Optional, Any - -import requests - -from haystack.preview import Document, component, default_to_dict, ComponentError - -logger = logging.getLogger(__name__) - - -SERPERDEV_BASE_URL = "https://google.serper.dev/search" - - -class SerperDevError(ComponentError): - ... - - -@component -class SerperDevWebSearch: - """ - Search engine using SerperDev API. Given a query, it returns a list of URLs that are the most relevant. 
- - See the [Serper Dev website](https://serper.dev/) for more details. - """ - - def __init__( - self, - api_key: Optional[str] = None, - top_k: Optional[int] = 10, - allowed_domains: Optional[List[str]] = None, - search_params: Optional[Dict[str, Any]] = None, - ): - """ - :param api_key: API key for the SerperDev API. It can be - explicitly provided or automatically read from the - environment variable SERPERDEV_API_KEY (recommended). - :param top_k: Number of documents to return. - :param allowed_domains: List of domains to limit the search to. - :param search_params: Additional parameters passed to the SerperDev API. - For example, you can set 'num' to 20 to increase the number of search results. - See the [Serper Dev website](https://serper.dev/) for more details. - """ - if api_key is None: - try: - api_key = os.environ["SERPERDEV_API_KEY"] - except KeyError as e: - raise ValueError( - "SerperDevWebSearch expects an API key. " - "Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly." - ) from e - raise ValueError("API key for SerperDev API must be set.") - self.api_key = api_key - self.top_k = top_k - self.allowed_domains = allowed_domains - self.search_params = search_params or {} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - return default_to_dict( - self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params - ) - - @component.output_types(documents=List[Document], links=List[str]) - def run(self, query: str): - """ - Search the SerperDev API for the given query and return the results as a list of Documents and a list of links. - - :param query: Query string. - """ - query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" - - payload = json.dumps( - {"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **self.search_params} - ) - headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} - - try: - response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) - response.raise_for_status() # Will raise an HTTPError for bad responses - except requests.Timeout: - raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") - - except requests.RequestException as e: - raise SerperDevError(f"An error occurred while querying {self.__class__.__name__}. 
Error: {e}") from e - - # If we reached this point, it means the request was successful and we can proceed - json_result = response.json() - - # we get the snippet from the json result and put it in the content field of the document - organic = [ - Document(meta={k: v for k, v in d.items() if k != "snippet"}, content=d["snippet"]) - for d in json_result["organic"] - ] - - # answer box is what search engine shows as a direct answer to the query - answer_box = [] - if "answerBox" in json_result: - answer_dict = json_result["answerBox"] - highlighted_answers = answer_dict.get("snippetHighlighted") - answer_box_content = None - # Check if highlighted_answers is a list and has at least one element - if isinstance(highlighted_answers, list) and len(highlighted_answers) > 0: - answer_box_content = highlighted_answers[0] - elif isinstance(highlighted_answers, str): - answer_box_content = highlighted_answers - if not answer_box_content: - for key in ["snippet", "answer", "title"]: - if key in answer_dict: - answer_box_content = answer_dict[key] - break - if answer_box_content: - answer_box = [ - Document( - content=answer_box_content, - meta={"title": answer_dict.get("title", ""), "link": answer_dict.get("link", "")}, - ) - ] - - # these are related questions that search engine shows - people_also_ask = [] - if "peopleAlsoAsk" in json_result: - for result in json_result["peopleAlsoAsk"]: - title = result.get("title", "") - people_also_ask.append( - Document( - content=result["snippet"] if result.get("snippet") else title, - meta={"title": title, "link": result.get("link", None)}, - ) - ) - - documents = answer_box + organic + people_also_ask - - links = [result["link"] for result in json_result["organic"]] - - logger.debug("Serper Dev returned %s documents for the query '%s'", len(documents), query) - return {"documents": documents[: self.top_k], "links": links[: self.top_k]} diff --git a/haystack/preview/components/writers/__init__.py b/haystack/preview/components/writers/__init__.py deleted file mode 100644 index 8328148352..0000000000 --- a/haystack/preview/components/writers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.components.writers.document_writer import DocumentWriter - -__all__ = ["DocumentWriter"] diff --git a/haystack/preview/components/writers/document_writer.py b/haystack/preview/components/writers/document_writer.py deleted file mode 100644 index 2ce45afde5..0000000000 --- a/haystack/preview/components/writers/document_writer.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import List, Optional, Dict, Any - -from haystack.preview import component, Document, default_from_dict, default_to_dict, DeserializationError -from haystack.preview.document_stores import DocumentStore, DuplicatePolicy, document_store - - -@component -class DocumentWriter: - """ - A component for writing documents to a DocumentStore. - """ - - def __init__(self, document_store: DocumentStore, policy: DuplicatePolicy = DuplicatePolicy.FAIL): - """ - Create a DocumentWriter component. - - :param policy: The policy to use when encountering duplicate documents (default is DuplicatePolicy.FAIL). - """ - self.document_store = document_store - self.policy = policy - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"document_store": type(self.document_store).__name__} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. 
- """ - return default_to_dict(self, document_store=self.document_store.to_dict(), policy=self.policy.name) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - if "document_store" not in init_params: - raise DeserializationError("Missing 'document_store' in serialization data") - if "type" not in init_params["document_store"]: - raise DeserializationError("Missing 'type' in document store's serialization data") - if init_params["document_store"]["type"] not in document_store.registry: - raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.") - docstore_class = document_store.registry[init_params["document_store"]["type"]] - docstore = docstore_class.from_dict(init_params["document_store"]) - - data["init_parameters"]["document_store"] = docstore - data["init_parameters"]["policy"] = DuplicatePolicy[data["init_parameters"]["policy"]] - return default_from_dict(cls, data) - - @component.output_types(documents_written=int) - def run(self, documents: List[Document], policy: Optional[DuplicatePolicy] = None): - """ - Run DocumentWriter on the given input data. - - :param documents: A list of documents to write to the store. - :param policy: The policy to use when encountering duplicate documents. - :return: Number of documents written - - :raises ValueError: If the specified document store is not found. - """ - if policy is None: - policy = self.policy - - documents_written = self.document_store.write_documents(documents=documents, policy=policy) - return {"documents_written": documents_written} diff --git a/haystack/preview/dataclasses/__init__.py b/haystack/preview/dataclasses/__init__.py deleted file mode 100644 index f27204bc4f..0000000000 --- a/haystack/preview/dataclasses/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from haystack.preview.dataclasses.document import Document -from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer -from haystack.preview.dataclasses.byte_stream import ByteStream -from haystack.preview.dataclasses.chat_message import ChatMessage -from haystack.preview.dataclasses.chat_message import ChatRole -from haystack.preview.dataclasses.streaming_chunk import StreamingChunk - -__all__ = [ - "Document", - "ExtractedAnswer", - "GeneratedAnswer", - "Answer", - "ByteStream", - "ChatMessage", - "ChatRole", - "StreamingChunk", -] diff --git a/haystack/preview/dataclasses/answer.py b/haystack/preview/dataclasses/answer.py deleted file mode 100644 index ed3c1ae0c1..0000000000 --- a/haystack/preview/dataclasses/answer.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Any, Dict, List, Optional -from dataclasses import dataclass -from haystack.preview.dataclasses.document import Document - - -@dataclass(frozen=True) -class Answer: - data: Any - query: str - metadata: Dict[str, Any] - - -@dataclass(frozen=True) -class ExtractedAnswer(Answer): - data: Optional[str] - document: Optional[Document] - probability: float - start: Optional[int] = None - end: Optional[int] = None - - -@dataclass(frozen=True) -class GeneratedAnswer(Answer): - data: str - documents: List[Document] diff --git a/haystack/preview/dataclasses/byte_stream.py b/haystack/preview/dataclasses/byte_stream.py deleted file mode 100644 index dd84e1c26b..0000000000 --- a/haystack/preview/dataclasses/byte_stream.py +++ /dev/null @@ -1,38 +0,0 @@ -from dataclasses import dataclass, field -from pathlib 
import Path -from typing import Optional, Dict, Any - - -@dataclass(frozen=True) -class ByteStream: - """ - Base data class representing a binary object in the Haystack API. - """ - - data: bytes - metadata: Dict[str, Any] = field(default_factory=dict, hash=False) - mime_type: Optional[str] = field(default=None) - - def to_file(self, destination_path: Path): - with open(destination_path, "wb") as fd: - fd.write(self.data) - - @classmethod - def from_file_path(cls, filepath: Path, mime_type: Optional[str] = None) -> "ByteStream": - """ - Create a ByteStream from the contents read from a file. - - :param filepath: A valid path to a file. - """ - with open(filepath, "rb") as fd: - return cls(data=fd.read(), mime_type=mime_type) - - @classmethod - def from_string(cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None) -> "ByteStream": - """ - Create a ByteStream encoding a string. - - :param text: The string to encode - :param encoding: The encoding used to convert the string into bytes - """ - return cls(data=text.encode(encoding), mime_type=mime_type) diff --git a/haystack/preview/dataclasses/chat_message.py b/haystack/preview/dataclasses/chat_message.py deleted file mode 100644 index 08c61d6cfc..0000000000 --- a/haystack/preview/dataclasses/chat_message.py +++ /dev/null @@ -1,80 +0,0 @@ -from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, Any, Optional - - -class ChatRole(str, Enum): - """Enumeration representing the roles within a chat.""" - - ASSISTANT = "assistant" - USER = "user" - SYSTEM = "system" - FUNCTION = "function" - - -@dataclass -class ChatMessage: - """ - Represents a message in a LLM chat conversation. - - :param content: The text content of the message. - :param role: The role of the entity sending the message. - :param name: The name of the function being called (only applicable for role FUNCTION). - :param metadata: Additional metadata associated with the message. - """ - - content: str - role: ChatRole - name: Optional[str] - metadata: Dict[str, Any] = field(default_factory=dict, hash=False) - - def is_from(self, role: ChatRole) -> bool: - """ - Check if the message is from a specific role. - - :param role: The role to check against. - :return: True if the message is from the specified role, False otherwise. - """ - return self.role == role - - @classmethod - def from_assistant(cls, content: str, metadata: Optional[Dict[str, Any]] = None) -> "ChatMessage": - """ - Create a message from the assistant. - - :param content: The text content of the message. - :param metadata: Additional metadata associated with the message. - :return: A new ChatMessage instance. - """ - return cls(content, ChatRole.ASSISTANT, None, metadata or {}) - - @classmethod - def from_user(cls, content: str) -> "ChatMessage": - """ - Create a message from the user. - - :param content: The text content of the message. - :return: A new ChatMessage instance. - """ - return cls(content, ChatRole.USER, None) - - @classmethod - def from_system(cls, content: str) -> "ChatMessage": - """ - Create a message from the system. - - :param content: The text content of the message. - :return: A new ChatMessage instance. - """ - return cls(content, ChatRole.SYSTEM, None) - - @classmethod - def from_function(cls, content: str, name: str) -> "ChatMessage": - """ - Create a message from a function call. - - :param content: The text content of the message. - :param name: The name of the function being called. - :return: A new ChatMessage instance. 
- """ - return cls(content, ChatRole.FUNCTION, name) diff --git a/haystack/preview/dataclasses/document.py b/haystack/preview/dataclasses/document.py deleted file mode 100644 index 168951edc0..0000000000 --- a/haystack/preview/dataclasses/document.py +++ /dev/null @@ -1,186 +0,0 @@ -import io -import hashlib -import logging -from dataclasses import asdict, dataclass, field, fields -from typing import Any, Dict, List, Optional - -import numpy -import pandas - -from haystack.preview.dataclasses.byte_stream import ByteStream - -logger = logging.getLogger(__name__) - - -class _BackwardCompatible(type): - """ - Metaclass that handles Document backward compatibility. - """ - - def __call__(cls, *args, **kwargs): - """ - Called before Document.__init__, will remap legacy fields to new ones. - Also handles building a Document from a flattened dictionary. - """ - # Move `content` to new fields depending on the type - content = kwargs.get("content") - if isinstance(content, pandas.DataFrame): - kwargs["dataframe"] = content - del kwargs["content"] - - # Not used anymore - if "content_type" in kwargs: - del kwargs["content_type"] - - # Embedding were stored as NumPy arrays in 1.x, so we convert it to the new type - if isinstance(embedding := kwargs.get("embedding"), numpy.ndarray): - kwargs["embedding"] = embedding.tolist() - - # id_hash_keys is not used anymore - if "id_hash_keys" in kwargs: - del kwargs["id_hash_keys"] - - return super().__call__(*args, **kwargs) - - -@dataclass -class Document(metaclass=_BackwardCompatible): - """ - Base data class containing some data to be queried. - Can contain text snippets, tables, and file paths to images or audios. - Documents can be sorted by score and saved to/from dictionary and JSON. - - :param id: Unique identifier for the document. When not set, it's generated based on the Document fields' values. - :param content: Text of the document, if the document contains text. - :param dataframe: Pandas dataframe with the document's content, if the document contains tabular data. - :param blob: Binary data associated with the document, if the document has any binary data associated with it. - :param meta: Additional custom metadata for the document. Must be JSON-serializable. - :param score: Score of the document. Used for ranking, usually assigned by retrievers. - :param embedding: Vector representation of the document. - """ - - id: str = field(default="") - content: Optional[str] = field(default=None) - dataframe: Optional[pandas.DataFrame] = field(default=None) - blob: Optional[ByteStream] = field(default=None) - meta: Dict[str, Any] = field(default_factory=dict) - score: Optional[float] = field(default=None) - embedding: Optional[List[float]] = field(default=None) - - def __repr__(self): - fields = [] - if self.content is not None: - fields.append( - f"content: '{self.content}'" if len(self.content) < 100 else f"content: '{self.content[:100]}...'" - ) - if self.dataframe is not None: - fields.append(f"dataframe: {self.dataframe.shape}") - if self.blob is not None: - fields.append(f"blob: {len(self.blob.data)} bytes") - if len(self.meta) > 0: - fields.append(f"meta: {self.meta}") - if self.score is not None: - fields.append(f"score: {self.score}") - if self.embedding is not None: - fields.append(f"embedding: vector of size {len(self.embedding)}") - fields_str = ", ".join(fields) - return f"{self.__class__.__name__}(id={self.id}, {fields_str})" - - def __eq__(self, other): - """ - Compares Documents for equality. 
- Two Documents are considered equals if their dictionary representation is identical. - """ - if type(self) != type(other): - return False - return self.to_dict() == other.to_dict() - - def __post_init__(self): - """ - Generate the ID based on the init parameters. - """ - # Generate an id only if not explicitly set - self.id = self.id or self._create_id() - - def _create_id(self): - """ - Creates a hash of the given content that acts as the document's ID. - """ - text = self.content or None - dataframe = self.dataframe.to_json() if self.dataframe is not None else None - blob = self.blob.data if self.blob is not None else None - mime_type = self.blob.mime_type if self.blob is not None else None - meta = self.meta or {} - embedding = self.embedding if self.embedding is not None else None - data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}" - return hashlib.sha256(data.encode("utf-8")).hexdigest() - - def to_dict(self, flatten=True) -> Dict[str, Any]: - """ - Converts Document into a dictionary. - `dataframe` and `blob` fields are converted to JSON-serializable types. - - :param flatten: Whether to flatten `meta` field or not. Defaults to `True` to be backward-compatible with Haystack 1.x. - """ - data = asdict(self) - if (dataframe := data.get("dataframe")) is not None: - data["dataframe"] = dataframe.to_json() - if (blob := data.get("blob")) is not None: - data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]} - - if flatten: - meta = data.pop("meta") - return {**data, **meta} - - return data - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Document": - """ - Creates a new Document object from a dictionary. - `dataframe` and `blob` fields are converted to their original types. - """ - if (dataframe := data.get("dataframe")) is not None: - data["dataframe"] = pandas.read_json(io.StringIO(dataframe)) - if blob := data.get("blob"): - data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"]) - # Store metadata for a moment while we try un-flattening allegedly flatten metadata. - # We don't expect both a `meta=` keyword and flatten metadata keys so we'll raise a - # ValueError later if this is the case. - meta = data.pop("meta", {}) - # Unflatten metadata if it was flattened. We assume any keyword argument that's not - # a document field is a metadata key. We treat legacy fields as document fields - # for backward compatibility. - flatten_meta = {} - legacy_fields = ["content_type", "id_hash_keys"] - document_fields = legacy_fields + [f.name for f in fields(cls)] - for key in list(data.keys()): - if key not in document_fields: - flatten_meta[key] = data.pop(key) - - # We don't support passing both flatten keys and the `meta` keyword parameter - if meta and flatten_meta: - raise ValueError( - "You can pass either the 'meta' parameter or flattened metadata keys as keyword arguments, " - "but currently you're passing both. Pass either the 'meta' parameter or flattened metadata keys." - ) - - # Finally put back all the metadata - return cls(**data, meta={**meta, **flatten_meta}) - - @property - def content_type(self): - """ - Returns the type of the content for the document. - This is necessary to keep backward compatibility with 1.x. - A ValueError will be raised if both `text` and `dataframe` fields are set - or both are missing. 
- """ - if self.content is not None and self.dataframe is not None: - raise ValueError("Both text and dataframe are set.") - - if self.content is not None: - return "text" - elif self.dataframe is not None: - return "table" - raise ValueError("Neither text nor dataframe is set.") diff --git a/haystack/preview/dataclasses/streaming_chunk.py b/haystack/preview/dataclasses/streaming_chunk.py deleted file mode 100644 index 1245560431..0000000000 --- a/haystack/preview/dataclasses/streaming_chunk.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass, field -from typing import Dict, Any - - -@dataclass -class StreamingChunk: - """ - The StreamingChunk class encapsulates a segment of streamed content along with - associated metadata. This structure facilitates the handling and processing of - streamed data in a systematic manner. - - :param content: The content of the message chunk as a string. - :param metadata: A dictionary containing metadata related to the message chunk. - """ - - content: str - metadata: Dict[str, Any] = field(default_factory=dict, hash=False) diff --git a/haystack/preview/document_stores/__init__.py b/haystack/preview/document_stores/__init__.py deleted file mode 100644 index 632c1f448f..0000000000 --- a/haystack/preview/document_stores/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from haystack.preview.document_stores.protocols import DocumentStore, DuplicatePolicy -from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore -from haystack.preview.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError -from haystack.preview.document_stores.decorator import document_store - -__all__ = [ - "DocumentStore", - "DuplicatePolicy", - "InMemoryDocumentStore", - "DocumentStoreError", - "DuplicateDocumentError", - "MissingDocumentError", - "document_store", -] diff --git a/haystack/preview/document_stores/decorator.py b/haystack/preview/document_stores/decorator.py deleted file mode 100644 index c82ccf91d9..0000000000 --- a/haystack/preview/document_stores/decorator.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) - - -class _DocumentStore: - """ - Marks a class as an Haystack _DocumentStore. - All classes decorated with @document_store will be registered here and can be used in Haystack Pipelines. - """ - - def __init__(self): - self.registry = {} - - def _decorate(self, cls): - cls.__haystack_document_store__ = True - - classname = f"{cls.__module__}.{cls.__name__}" - if classname in self.registry: - logger.error( - "DocumentStore %s is already registered. 
Previous imported from '%s', new imported from '%s'", - classname, - self.registry[classname], - cls, - ) - - self.registry[classname] = cls - logger.debug("Registered DocumentStore %s", cls) - - return cls - - def __call__(self, cls=None): - if cls: - return self._decorate(cls) - - return self._decorate - - -document_store = _DocumentStore() diff --git a/haystack/preview/document_stores/errors.py b/haystack/preview/document_stores/errors.py deleted file mode 100644 index c345b04e50..0000000000 --- a/haystack/preview/document_stores/errors.py +++ /dev/null @@ -1,10 +0,0 @@ -class DocumentStoreError(Exception): - pass - - -class DuplicateDocumentError(DocumentStoreError): - pass - - -class MissingDocumentError(DocumentStoreError): - pass diff --git a/haystack/preview/document_stores/in_memory/__init__.py b/haystack/preview/document_stores/in_memory/__init__.py deleted file mode 100644 index 5b3644431a..0000000000 --- a/haystack/preview/document_stores/in_memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.document_stores.in_memory.document_store import InMemoryDocumentStore - -__all__ = ["InMemoryDocumentStore"] diff --git a/haystack/preview/document_stores/in_memory/document_store.py b/haystack/preview/document_stores/in_memory/document_store.py deleted file mode 100644 index f00c4199bd..0000000000 --- a/haystack/preview/document_stores/in_memory/document_store.py +++ /dev/null @@ -1,328 +0,0 @@ -import re -from typing import Literal, Any, Dict, List, Optional, Iterable - -import logging - -import numpy as np -import rank_bm25 -from tqdm.auto import tqdm - -from haystack.preview import default_from_dict, default_to_dict -from haystack.preview.document_stores.decorator import document_store -from haystack.preview.dataclasses import Document -from haystack.preview.document_stores.protocols import DuplicatePolicy -from haystack.preview.utils.filters import document_matches_filter, convert -from haystack.preview.document_stores.errors import DuplicateDocumentError, DocumentStoreError -from haystack.preview.utils import expit - -logger = logging.getLogger(__name__) - -# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to -# True (default). Scaling uses the expit function (inverse of the logit function) after applying a scaling factor -# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method). -# Larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with BM25_SCALING_FACTOR=2 -# but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. Increase the default if most -# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1. -BM25_SCALING_FACTOR = 8 -DOT_PRODUCT_SCALING_FACTOR = 100 - - -@document_store -class InMemoryDocumentStore: - """ - Stores data in-memory. It's ephemeral and cannot be saved to disk. - """ - - def __init__( - self, - bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", - bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi", - bm25_parameters: Optional[Dict] = None, - embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product", - ): - """ - Initializes the DocumentStore. - - :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval. - :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus". 
- :param bm25_parameters: Parameters for BM25 implementation in a dictionary format. - For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25} - You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25. - By default, no parameters are set. - :param embedding_similarity_function: The similarity function used to compare Documents embeddings. - One of "dot_product" (default) or "cosine". - To choose the most appropriate function, look for information about your embedding model. - """ - self.storage: Dict[str, Document] = {} - self._bm25_tokenization_regex = bm25_tokenization_regex - self.tokenizer = re.compile(bm25_tokenization_regex).findall - algorithm_class = getattr(rank_bm25, bm25_algorithm) - if algorithm_class is None: - raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.") - self.bm25_algorithm = algorithm_class - self.bm25_parameters = bm25_parameters or {} - self.embedding_similarity_function = embedding_similarity_function - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this store to a dictionary. - """ - return default_to_dict( - self, - bm25_tokenization_regex=self._bm25_tokenization_regex, - bm25_algorithm=self.bm25_algorithm.__name__, - bm25_parameters=self.bm25_parameters, - embedding_similarity_function=self.embedding_similarity_function, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore": - """ - Deserializes the store from a dictionary. - """ - return default_from_dict(cls, data) - - def count_documents(self) -> int: - """ - Returns the number of how many documents are present in the DocumentStore. - """ - return len(self.storage.keys()) - - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - """ - Returns the documents that match the filters provided. - - For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol documentation. - - :param filters: The filters to apply to the document list. - :return: A list of Documents that match the given filters. - """ - if filters: - if "operator" not in filters: - filters = convert(filters) - return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)] - return list(self.storage.values()) - - def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> int: - """ - Writes (or overwrites) documents into the DocumentStore. - - :param documents: A list of documents. - :param policy: Documents with the same ID count as duplicates. When duplicates are met, - the DocumentStore can: - - skip: keep the existing document and ignore the new one. - - overwrite: remove the old document and write the new one. - - fail: an error is raised. 
- :raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL` - :return: None - """ - if ( - not isinstance(documents, Iterable) - or isinstance(documents, str) - or any(not isinstance(doc, Document) for doc in documents) - ): - raise ValueError("Please provide a list of Documents.") - - written_documents = len(documents) - for document in documents: - if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage.keys(): - if policy == DuplicatePolicy.FAIL: - raise DuplicateDocumentError(f"ID '{document.id}' already exists.") - if policy == DuplicatePolicy.SKIP: - logger.warning("ID '%s' already exists", document.id) - written_documents -= 1 - continue - self.storage[document.id] = document - return written_documents - - def delete_documents(self, document_ids: List[str]) -> None: - """ - Deletes all documents with matching document_ids from the DocumentStore. - :param object_ids: The object_ids to delete. - """ - for doc_id in document_ids: - if doc_id not in self.storage.keys(): - continue - del self.storage[doc_id] - - def bm25_retrieval( - self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False - ) -> List[Document]: - """ - Retrieves documents that are most relevant to the query using BM25 algorithm. - - :param query: The query string. - :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The number of top documents to retrieve. Default is 10. - :param scale_score: Whether to scale the scores of the retrieved documents. Default is False. - :return: A list of the top_k documents most relevant to the query. - """ - if not query: - raise ValueError("Query should be a non-empty string") - - content_type_filter = { - "operator": "OR", - "conditions": [ - {"field": "content", "operator": "!=", "value": None}, - {"field": "dataframe", "operator": "!=", "value": None}, - ], - } - if filters: - if "operator" not in filters: - filters = convert(filters) - filters = {"operator": "AND", "conditions": [content_type_filter, filters]} - else: - filters = content_type_filter - all_documents = self.filter_documents(filters=filters) - - # Lowercase all documents - lower_case_documents = [] - for doc in all_documents: - if doc.content is None and doc.dataframe is None: - logger.info("Document '%s' has no text or dataframe content. Skipping it.", doc.id) - else: - if doc.content is not None: - lower_case_documents.append(doc.content.lower()) - if doc.dataframe is not None: - logger.warning( - "Document '%s' has both text and dataframe content. " - "Using text content and skipping dataframe content.", - doc.id, - ) - continue - if doc.dataframe is not None: - str_content = doc.dataframe.astype(str) - csv_content = str_content.to_csv(index=False) - lower_case_documents.append(csv_content.lower()) - - # Tokenize the entire content of the DocumentStore - tokenized_corpus = [ - self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...") - ] - if len(tokenized_corpus) == 0: - logger.info("No documents found for BM25 retrieval. 
Returning empty list.") - return [] - - # initialize BM25 - bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters) - # tokenize query - tokenized_query = self.tokenizer(query.lower()) - # get scores for the query against the corpus - docs_scores = bm25_scorer.get_scores(tokenized_query) - if scale_score: - docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores] - # get the last top_k indexes and reverse them - top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1] - - # Create documents with the BM25 score to return them - return_documents = [] - for i in top_docs_positions: - doc = all_documents[i] - doc_fields = doc.to_dict() - doc_fields["score"] = docs_scores[i] - return_document = Document.from_dict(doc_fields) - return_documents.append(return_document) - return return_documents - - def embedding_retrieval( - self, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, - top_k: int = 10, - scale_score: bool = False, - return_embedding: bool = False, - ) -> List[Document]: - """ - Retrieves documents that are most similar to the query embedding using a vector similarity metric. - - :param query_embedding: Embedding of the query. - :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The number of top documents to retrieve. Default is 10. - :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False. - :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. - :return: A list of the top_k documents most relevant to the query. - """ - if len(query_embedding) == 0 or not isinstance(query_embedding[0], float): - raise ValueError("query_embedding should be a non-empty list of floats.") - - filters = filters or {} - all_documents = self.filter_documents(filters=filters) - - documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None] - if len(documents_with_embeddings) == 0: - logger.warning( - "No Documents found with embeddings. Returning empty list. " - "To generate embeddings, use a DocumentEmbedder." - ) - return [] - elif len(documents_with_embeddings) < len(all_documents): - logger.info( - "Skipping some Documents that don't have an embedding. " - "To generate embeddings, use a DocumentEmbedder." - ) - - scores = self._compute_query_embedding_similarity_scores( - embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score - ) - - # create Documents with the similarity score for the top k results - top_documents = [] - for doc, score in sorted(zip(documents_with_embeddings, scores), key=lambda x: x[1], reverse=True)[:top_k]: - doc_fields = doc.to_dict() - doc_fields["score"] = score - if return_embedding is False: - doc_fields["embedding"] = None - top_documents.append(Document.from_dict(doc_fields)) - - return top_documents - - def _compute_query_embedding_similarity_scores( - self, embedding: List[float], documents: List[Document], scale_score: bool = False - ) -> List[float]: - """ - Computes the similarity scores between the query embedding and the embeddings of the documents. - - :param embedding: Embedding of the query. - :param documents: A list of Documents. - :param scale_score: Whether to scale the scores of the Documents. Default is False. - :return: A list of scores. 
- """ - - query_embedding = np.array(embedding) - if query_embedding.ndim == 1: - query_embedding = np.expand_dims(a=query_embedding, axis=0) - - try: - document_embeddings = np.array([doc.embedding for doc in documents]) - except ValueError as e: - if "inhomogeneous shape" in str(e): - raise DocumentStoreError( - "The embedding size of all Documents should be the same. " - "Please make sure that the Documents have been embedded with the same model." - ) from e - raise e - if document_embeddings.ndim == 1: - document_embeddings = np.expand_dims(a=document_embeddings, axis=0) - - if self.embedding_similarity_function == "cosine": - # cosine similarity is a normed dot product - query_embedding /= np.linalg.norm(x=query_embedding, axis=1, keepdims=True) - document_embeddings /= np.linalg.norm(x=document_embeddings, axis=1, keepdims=True) - - try: - scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist() - except ValueError as e: - if "shapes" in str(e) and "not aligned" in str(e): - raise DocumentStoreError( - "The embedding size of the query should be the same as the embedding size of the Documents. " - "Please make sure that the query has been embedded with the same model as the Documents." - ) from e - raise e - - if scale_score: - if self.embedding_similarity_function == "dot_product": - scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores] - elif self.embedding_similarity_function == "cosine": - scores = [(score + 1) / 2 for score in scores] - - return scores diff --git a/haystack/preview/document_stores/protocols.py b/haystack/preview/document_stores/protocols.py deleted file mode 100644 index 6a27f19551..0000000000 --- a/haystack/preview/document_stores/protocols.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Protocol, Optional, Dict, Any, List -import logging -from enum import Enum - -from haystack.preview.dataclasses import Document - - -# Ellipsis are needed for the type checker, it's safe to disable module-wide -# pylint: disable=unnecessary-ellipsis - -logger = logging.getLogger(__name__) - - -class DuplicatePolicy(Enum): - SKIP = "skip" - OVERWRITE = "overwrite" - FAIL = "fail" - - -class DocumentStore(Protocol): - """ - Stores Documents to be used by the components of a Pipeline. - - Classes implementing this protocol often store the documents permanently and allow specialized components to - perform retrieval on them, either by embedding, by keyword, hybrid, and so on, depending on the backend used. - - In order to retrieve documents, consider using a Retriever that supports the DocumentStore implementation that - you're using. - """ - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this store to a dictionary. - """ - ... - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "DocumentStore": - """ - Deserializes the store from a dictionary. - """ - ... - - def count_documents(self) -> int: - """ - Returns the number of documents stored. - """ - ... - - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - """ - Returns the documents that match the filters provided. - - Filters are defined as nested dictionaries that can be of two types: - - Comparison - - Logic - - Comparison dictionaries must contain the keys: - - - `field` - - `operator` - - `value` - - Logic dictionaries must contain the keys: - - - `operator` - - `conditions` - - The `conditions` key must be a list of dictionaries, either of type Comparison or Logic. 
- - The `operator` value in Comparison dictionaries must be one of: - - - `==` - - `!=` - - `>` - - `>=` - - `<` - - `<=` - - `in` - - `not in` - - The `operator` values in Logic dictionaries must be one of: - - - `NOT` - - `OR` - - `AND` - - - A simple filter: - ```python - filters = {"field": "meta.type", "operator": "==", "value": "article"} - ``` - - A more complex filter: - ```python - filters = { - "operator": "AND", - "conditions": [ - {"field": "meta.type", "operator": "==", "value": "article"}, - {"field": "meta.date", "operator": ">=", "value": 1420066800}, - {"field": "meta.date", "operator": "<", "value": 1609455600}, - {"field": "meta.rating", "operator": ">=", "value": 3}, - { - "operator": "OR", - "conditions": [ - {"field": "meta.genre", "operator": "in", "value": ["economy", "politics"]}, - {"field": "meta.publisher", "operator": "==", "value": "nytimes"}, - ], - }, - ], - } - - :param filters: the filters to apply to the document list. - :return: a list of Documents that match the given filters. - """ - ... - - def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> int: - """ - Writes (or overwrites) documents into the DocumentStore. - - :param documents: a list of documents. - :param policy: documents with the same ID count as duplicates. When duplicates are met, - the DocumentStore can: - - skip: keep the existing document and ignore the new one. - - overwrite: remove the old document and write the new one. - - fail: an error is raised - :raises DuplicateError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL` - :return: The number of documents that was written. - If DuplicatePolicy.OVERWRITE is used, this number is always equal to the number of documents in input. - If DuplicatePolicy.SKIP is used, this number can be lower than the number of documents in the input list. - """ - ... - - def delete_documents(self, document_ids: List[str]) -> None: - """ - Deletes all documents with a matching document_ids from the DocumentStore. - Fails with `MissingDocumentError` if no document with this id is present in the DocumentStore. - - :param object_ids: the object_ids to delete - """ - ... diff --git a/haystack/preview/errors.py b/haystack/preview/errors.py deleted file mode 100644 index c7a6c47d6d..0000000000 --- a/haystack/preview/errors.py +++ /dev/null @@ -1,2 +0,0 @@ -class FilterError(Exception): - pass diff --git a/haystack/preview/lazy_imports.py b/haystack/preview/lazy_imports.py deleted file mode 100644 index 5f474beef9..0000000000 --- a/haystack/preview/lazy_imports.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional, Type -from types import TracebackType -from lazy_imports.try_import import _DeferredImportExceptionContextManager - - -DEFAULT_IMPORT_ERROR_MSG = "Try 'pip install {}'" - - -class LazyImport(_DeferredImportExceptionContextManager): - """ - Wrapper on top of lazy_import's _DeferredImportExceptionContextManager that adds the possibility to customize the - error messages. - """ - - def __init__(self, message: str = DEFAULT_IMPORT_ERROR_MSG) -> None: - super().__init__() - self.import_error_msg = message - - def __exit__( - self, exc_type: Optional[Type[Exception]], exc_value: Optional[Exception], traceback: Optional[TracebackType] - ) -> Optional[bool]: - """Exit the context manager. - - Args: - exc_type: - Raised exception type. :obj:`None` if nothing is raised. - exc_value: - Raised exception object. :obj:`None` if nothing is raised. - traceback: - Associated traceback. 
:obj:`None` if nothing is raised. - - Returns: - :obj:`None` if nothing is deferred, otherwise :obj:`True`. - :obj:`True` will suppress any exceptions avoiding them from propagating. - - """ - if isinstance(exc_value, ImportError): - message = ( - f"Failed to import '{exc_value.name}'. {self.import_error_msg.format(exc_value.name)}. " - f"Original error: {exc_value}" - ) - self._deferred = (exc_value, message) - return True - return None diff --git a/haystack/preview/marshal/__init__.py b/haystack/preview/marshal/__init__.py deleted file mode 100644 index f737be0574..0000000000 --- a/haystack/preview/marshal/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from haystack.preview.marshal.protocol import Marshaller -from haystack.preview.marshal.yaml import YamlMarshaller - -__all__ = ["Marshaller", "YamlMarshaller"] diff --git a/haystack/preview/marshal/protocol.py b/haystack/preview/marshal/protocol.py deleted file mode 100644 index 06663b7534..0000000000 --- a/haystack/preview/marshal/protocol.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Protocol, Dict, Any, Union - - -class Marshaller(Protocol): - def marshal(self, dict_: Dict[str, Any]) -> str: - ... - - def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]: - ... diff --git a/haystack/preview/marshal/yaml.py b/haystack/preview/marshal/yaml.py deleted file mode 100644 index 5fca27fb6f..0000000000 --- a/haystack/preview/marshal/yaml.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Dict, Any, Union - -import yaml - - -class YamlMarshaller: - def marshal(self, dict_: Dict[str, Any]) -> str: - return yaml.dump(dict_) - - def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]: - return yaml.safe_load(data_) diff --git a/haystack/preview/pipeline.py b/haystack/preview/pipeline.py deleted file mode 100644 index 275277295e..0000000000 --- a/haystack/preview/pipeline.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Any, Dict, Optional, Union, TextIO -from pathlib import Path -import datetime -import logging -import canals - -from haystack.preview.telemetry import pipeline_running -from haystack.preview.marshal import Marshaller, YamlMarshaller - - -DEFAULT_MARSHALLER = YamlMarshaller() -logger = logging.getLogger(__name__) - - -class Pipeline(canals.Pipeline): - def __init__( - self, - metadata: Optional[Dict[str, Any]] = None, - max_loops_allowed: int = 100, - debug_path: Union[Path, str] = Path(".haystack_debug/"), - ): - """ - Creates the Pipeline. - - Args: - metadata: arbitrary dictionary to store metadata about this pipeline. Make sure all the values contained in - this dictionary can be serialized and deserialized if you wish to save this pipeline to file with - `save_pipelines()/load_pipelines()`. - max_loops_allowed: how many times the pipeline can run the same node before throwing an exception. - debug_path: when debug is enabled in `run()`, where to save the debug data. - """ - self._telemetry_runs = 0 - self._last_telemetry_sent: Optional[datetime.datetime] = None - super().__init__(metadata=metadata, max_loops_allowed=max_loops_allowed, debug_path=debug_path) - - def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]: - """ - Runs the pipeline. - - :params data: the inputs to give to the input components of the Pipeline. - :params debug: whether to collect and return debug information. - - :returns: A dictionary with the outputs of the output components of the Pipeline. 
- - :raises PipelineRuntimeError: if the any of the components fail or return unexpected output. - """ - pipeline_running(self) - return super().run(data=data, debug=debug) - - def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str: - """ - Returns the string representation of this pipeline according to the - format dictated by the `Marshaller` in use. - - :params marshaller: The Marshaller used to create the string representation. Defaults to - `YamlMarshaller` - - :returns: A string representing the pipeline. - """ - return marshaller.marshal(self.to_dict()) - - def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER): - """ - Writes the string representation of this pipeline to the file-like object - passed in the `fp` argument. - - :params fp: A file-like object ready to be written to. - :params marshaller: The Marshaller used to create the string representation. Defaults to - `YamlMarshaller`. - """ - fp.write(marshaller.marshal(self.to_dict())) - - @classmethod - def loads(cls, data: Union[str, bytes, bytearray], marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline": - """ - Creates a `Pipeline` object from the string representation passed in the `data` argument. - - :params data: The string representation of the pipeline, can be `str`, `bytes` or `bytearray`. - :params marshaller: the Marshaller used to create the string representation. Defaults to - `YamlMarshaller` - - :returns: A `Pipeline` object. - """ - return cls.from_dict(marshaller.unmarshal(data)) - - @classmethod - def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline": - """ - Creates a `Pipeline` object from the string representation read from the file-like - object passed in the `fp` argument. - - :params data: The string representation of the pipeline, can be `str`, `bytes` or `bytearray`. - :params fp: A file-like object ready to be read from. - :params marshaller: the Marshaller used to create the string representation. Defaults to - `YamlMarshaller` - - :returns: A `Pipeline` object. 
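To make the `dumps()`/`loads()` round trip concrete, here is a minimal sketch of the underlying `Marshaller` step in isolation, using the same `YamlMarshaller` shown earlier in this diff. The payload dictionary is a stand-in, not the real `to_dict()` output.

```python
# Minimal sketch of the YAML round trip that Pipeline.dumps()/loads() delegate to.
# Requires PyYAML; the payload dict is a stand-in for a serialized pipeline.
import yaml

class YamlMarshaller:
    def marshal(self, dict_):
        return yaml.dump(dict_)

    def unmarshal(self, data_):
        return yaml.safe_load(data_)

marshaller = YamlMarshaller()
payload = {"components": {}, "metadata": {"owner": "example"}}  # stand-in pipeline dict
serialized = marshaller.marshal(payload)             # analogous to Pipeline.dumps()
assert marshaller.unmarshal(serialized) == payload   # analogous to Pipeline.loads()
```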
- """ - return cls.from_dict(marshaller.unmarshal(fp.read())) diff --git a/haystack/preview/telemetry/__init__.py b/haystack/preview/telemetry/__init__.py deleted file mode 100644 index be32ab8102..0000000000 --- a/haystack/preview/telemetry/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from haystack.preview.telemetry._telemetry import pipeline_running, tutorial_running diff --git a/haystack/preview/telemetry/_environment.py b/haystack/preview/telemetry/_environment.py deleted file mode 100644 index c450c19320..0000000000 --- a/haystack/preview/telemetry/_environment.py +++ /dev/null @@ -1,106 +0,0 @@ -# pylint: disable=global-statement -import logging -import os -import platform -import sys -from typing import Optional, Dict, Any - -from haystack.preview.version import __version__ - -logger = logging.getLogger(__name__) - - -# This value cannot change during the lifetime of the process -_IS_DOCKER_CACHE = None - - -def _in_podman() -> bool: - """ - Podman run would create the file /run/.containernv, see: - https://github.com/containers/podman/blob/main/docs/source/markdown/podman-run.1.md.in#L31 - """ - return os.path.exists("/run/.containerenv") - - -def _has_dockerenv() -> bool: - """ - This might not work anymore at some point (even if it's been a while now), see: - https://github.com/moby/moby/issues/18355#issuecomment-220484748 - """ - return os.path.exists("/.dockerenv") - - -def _has_docker_cgroup_v1() -> bool: - """ - This only works with cgroups v1 - """ - path = "/proc/self/cgroup" # 'self' should be always symlinked to the actual PID - return os.path.isfile(path) and any("docker" in line for line in open(path)) - - -def _has_docker_cgroup_v2() -> bool: - """ - cgroups v2 version, inspired from - https://github.com/jenkinsci/docker-workflow-plugin/blob/master/src/main/java/org/jenkinsci/plugins/docker/workflow/client/DockerClient.java - """ - path = "/proc/self/mountinfo" # 'self' should be always symlinked to the actual PID - return os.path.isfile(path) and any("/docker/containers/" in line for line in open(path)) - - -def _is_containerized() -> Optional[bool]: - """ - This code is based on the popular 'is-docker' package for node.js - """ - global _IS_DOCKER_CACHE - - if _IS_DOCKER_CACHE is None: - _IS_DOCKER_CACHE = _in_podman() or _has_dockerenv() or _has_docker_cgroup_v1() or _has_docker_cgroup_v2() - - return _IS_DOCKER_CACHE - - -def collect_system_specs() -> Dict[str, Any]: - """ - Collects meta data about the setup that is used with Haystack, such as: - operating system, python version, Haystack version, transformers version, - pytorch version, number of GPUs, execution environment. - - These values are highly unlikely to change during the runtime of the pipeline, - so they're collected only once. 
- """ - specs = { - "libraries.haystack": __version__, - "os.containerized": _is_containerized(), - "os.version": platform.release(), - "os.family": platform.system(), - "os.machine": platform.machine(), - "python.version": platform.python_version(), - "hardware.cpus": os.cpu_count(), - "hardware.gpus": 0, - "libraries.transformers": False, - "libraries.torch": False, - "libraries.cuda": False, - "libraries.pytest": sys.modules["pytest"].__version__ if "pytest" in sys.modules.keys() else False, - "libraries.ipython": sys.modules["ipython"].__version__ if "ipython" in sys.modules.keys() else False, - "libraries.colab": sys.modules["google.colab"].__version__ if "google.colab" in sys.modules.keys() else False, - } - - # Try to find out transformer's version - try: - import transformers - - specs["libraries.transformers"] = transformers.__version__ - except ImportError: - pass - - # Try to find out torch's version and info on potential GPU(s) - try: - import torch - - specs["libraries.torch"] = torch.__version__ - if torch.cuda.is_available(): - specs["libraries.cuda"] = torch.version.cuda - specs["libraries.gpus"] = torch.cuda.device_count() - except ImportError: - pass - return specs diff --git a/haystack/preview/telemetry/_telemetry.py b/haystack/preview/telemetry/_telemetry.py deleted file mode 100644 index 24a0e9d9db..0000000000 --- a/haystack/preview/telemetry/_telemetry.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import Any, Dict, Optional, TYPE_CHECKING, List, Tuple -import os -from pathlib import Path -from collections import defaultdict -import datetime -import logging -import uuid -import yaml -import posthog - -from haystack.preview.telemetry._environment import collect_system_specs - -if TYPE_CHECKING: - from haystack.preview.pipeline import Pipeline - - -HAYSTACK_TELEMETRY_ENABLED = "HAYSTACK_TELEMETRY_ENABLED" -CONFIG_PATH = Path("~/.haystack/config.yaml").expanduser() - -#: Telemetry sends at most one event every number of seconds specified in this constant -MIN_SECONDS_BETWEEN_EVENTS = 60 - - -logger = logging.getLogger(__name__) - - -class Telemetry: - """ - Haystack reports anonymous usage statistics to support continuous software improvements for all its users. - - You can opt-out of sharing usage statistics by manually setting the environment - variable `HAYSTACK_TELEMETRY_ENABLED` as described for different operating systems on the - [documentation page](https://docs.haystack.deepset.ai/docs/telemetry#how-can-i-opt-out). - - Check out the documentation for more details: [Telemetry](https://docs.haystack.deepset.ai/docs/telemetry). - """ - - def __init__(self): - """ - Initializes the telemetry. Loads the user_id from the config file, - or creates a new id and saves it if the file is not found. - - It also collects system information which cannot change across the lifecycle - of the process (for example `is_containerized()`). 
- """ - - # disable posthog logging - for module_name in ["posthog", "backoff"]: - logging.getLogger(module_name).setLevel(logging.CRITICAL) - # Prevent module from sending errors to stderr when an exception is encountered during an emit() call - logging.getLogger(module_name).addHandler(logging.NullHandler()) - logging.getLogger(module_name).propagate = False - - self.user_id = None - - if CONFIG_PATH.exists(): - # Load the config file - try: - with open(CONFIG_PATH, "r", encoding="utf-8") as config_file: - config = yaml.safe_load(config_file) - if "user_id" in config: - self.user_id = config["user_id"] - except Exception as e: - logger.debug("Telemetry could not read the config file %s", CONFIG_PATH, exc_info=e) - else: - # Create the config file - logger.info( - "Haystack sends anonymous usage data to understand the actual usage and steer dev efforts " - "towards features that are most meaningful to users. You can opt-out at anytime by manually " - "setting the environment variable HAYSTACK_TELEMETRY_ENABLED as described for different " - "operating systems in the [documentation page](https://docs.haystack.deepset.ai/docs/telemetry#how-can-i-opt-out). " - "More information at [Telemetry](https://docs.haystack.deepset.ai/docs/telemetry)." - ) - CONFIG_PATH.parents[0].mkdir(parents=True, exist_ok=True) - self.user_id = str(uuid.uuid4()) - try: - with open(CONFIG_PATH, "w") as outfile: - yaml.dump({"user_id": self.user_id}, outfile, default_flow_style=False) - except Exception as e: - logger.debug("Telemetry could not write config file to %s", CONFIG_PATH, exc_info=e) - - self.event_properties = collect_system_specs() - - def send_event(self, event_name: str, event_properties: Optional[Dict[str, Any]] = None): - """ - Sends a telemetry event. - - :param event_name: The name of the event to show in PostHog. - :param event_properties: Additional event metadata. These are merged with the - system metadata collected in __init__, so take care not to overwrite them. - """ - event_properties = event_properties or {} - try: - posthog.capture( - distinct_id=self.user_id, event=event_name, properties={**self.event_properties, **event_properties} - ) - except Exception as e: - logger.debug("Telemetry couldn't make a POST request to PostHog.", exc_info=e) - - -def send_telemetry(func): - """ - Decorator that sends the output of the wrapped function to PostHog. - The wrapped function is actually called only if telemetry is enabled. - """ - - # FIXME? Somehow, functools.wraps makes `telemetry` out of scope. Let's take care of it later. - def send_telemetry_wrapper(*args, **kwargs): - try: - if telemetry: - output = func(*args, **kwargs) - if output: - telemetry.send_event(*output) - except Exception as e: - # Never let telemetry break things - logger.debug("There was an issue sending a telemetry event", exc_info=e) - - return send_telemetry_wrapper - - -@send_telemetry -def pipeline_running(pipeline: "Pipeline") -> Optional[Tuple[str, Dict[str, Any]]]: - """ - Collects name, type and the content of the _telemetry_data attribute, if present, for each component in the - pipeline and sends such data to Posthog. - - :param pipeline: the pipeline that is running. 
- """ - pipeline._telemetry_runs += 1 - if ( - pipeline._last_telemetry_sent - and (datetime.datetime.now() - pipeline._last_telemetry_sent).seconds < MIN_SECONDS_BETWEEN_EVENTS - ): - return None - - pipeline._last_telemetry_sent = datetime.datetime.now() - - # Collect info about components - pipeline_description = pipeline.to_dict() - components: Dict[str, List[Dict[str, Any]]] = defaultdict(list) - for component_name, component in pipeline_description["components"].items(): - instance = pipeline.get_component(component_name) - if hasattr(instance, "_get_telemetry_data"): - telemetry_data = getattr(instance, "_get_telemetry_data")() - try: - components[component["type"]].append({"name": component_name, **telemetry_data}) - except TypeError: - components[component["type"]].append({"name": component_name}) - else: - components[component["type"]].append({"name": component_name}) - - # Data sent to Posthog - return "Pipeline run (2.x)", { - "pipeline_id": str(id(pipeline)), - "runs": pipeline._telemetry_runs, - "components": components, - } - - -@send_telemetry -def tutorial_running(tutorial_id: str) -> Tuple[str, Dict[str, Any]]: - """ - Send a telemetry event for a tutorial, if telemetry is enabled. - :param tutorial_id: identifier of the tutorial - """ - return "Tutorial", {"tutorial.id": tutorial_id} - - -telemetry = None -if os.getenv("HAYSTACK_TELEMETRY_ENABLED", "true").lower() in ("true", "1"): - telemetry = Telemetry() diff --git a/haystack/preview/testing/__init__.py b/haystack/preview/testing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/haystack/preview/testing/document_store.py b/haystack/preview/testing/document_store.py deleted file mode 100644 index bb8c1f15fb..0000000000 --- a/haystack/preview/testing/document_store.py +++ /dev/null @@ -1,866 +0,0 @@ -# pylint: disable=too-many-public-methods -from typing import List -import random - -import pytest -import pandas as pd - -from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import DocumentStore, DuplicatePolicy -from haystack.preview.document_stores.errors import DuplicateDocumentError -from haystack.preview.errors import FilterError - - -def _random_embeddings(n): - return [random.random() for _ in range(n)] - - -# These are random embedding that are used to test filters. -# We declare them here as they're used both in the `filterable_docs` fixture -# and the body of several `filter_documents` tests. -TEST_EMBEDDING_1 = _random_embeddings(768) -TEST_EMBEDDING_2 = _random_embeddings(768) - - -class CountDocumentsTest: - """ - Utility class to test a Document Store `count_documents` method. - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(CountDocumentsTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_count_empty(self, document_store: DocumentStore): - assert document_store.count_documents() == 0 - - @pytest.mark.unit - def test_count_not_empty(self, document_store: DocumentStore): - document_store.write_documents( - [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] - ) - assert document_store.count_documents() == 3 - - -class WriteDocumentsTest: - """ - Utility class to test a Document Store `write_documents` method. 
- - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - The Document Store `filter_documents` method must be at least partly implemented to return all stored Documents - for this tests to work correctly. - Example usage: - - ```python - class MyDocumentStoreTest(WriteDocumentsTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_write_documents(self, document_store: DocumentStore): - """ - Test write_documents() normal behaviour. - """ - doc = Document(content="test doc") - assert document_store.write_documents([doc]) == 1 - assert document_store.filter_documents() == [doc] - - @pytest.mark.unit - def test_write_documents_duplicate_fail(self, document_store: DocumentStore): - """ - Test write_documents() fails when trying to write Document with same id - using DuplicatePolicy.FAIL. - """ - doc = Document(content="test doc") - assert document_store.write_documents([doc]) == 1 - with pytest.raises(DuplicateDocumentError): - document_store.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL) - assert document_store.filter_documents() == [doc] - - @pytest.mark.unit - def test_write_documents_duplicate_skip(self, document_store: DocumentStore): - """ - Test write_documents() skips Document when trying to write one with same id - using DuplicatePolicy.SKIP. - """ - doc = Document(content="test doc") - assert document_store.write_documents([doc]) == 1 - assert document_store.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP) == 0 - - @pytest.mark.unit - def test_write_documents_duplicate_overwrite(self, document_store: DocumentStore): - """ - Test write_documents() overwrites stored Document when trying to write one with same id - using DuplicatePolicy.OVERWRITE. - """ - doc1 = Document(id="1", content="test doc 1") - doc2 = Document(id="1", content="test doc 2") - - assert document_store.write_documents([doc2]) == 1 - assert document_store.filter_documents() == [doc2] - assert document_store.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) == 1 - assert document_store.filter_documents() == [doc1] - - @pytest.mark.unit - def test_write_documents_invalid_input(self, document_store: DocumentStore): - """ - Test write_documents() fails when providing unexpected input. - """ - with pytest.raises(ValueError): - document_store.write_documents(["not a document for sure"]) # type: ignore - with pytest.raises(ValueError): - document_store.write_documents("not a list actually") # type: ignore - - -class DeleteDocumentsTest: - """ - Utility class to test a Document Store `delete_documents` method. - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - The Document Store `write_documents` and `count_documents` methods must be implemented for this tests to work correctly. - Example usage: - - ```python - class MyDocumentStoreTest(DeleteDocumentsTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_delete_documents(self, document_store: DocumentStore): - """ - Test delete_documents() normal behaviour. 
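For reference, a toy dict-backed store showing the three `DuplicatePolicy` behaviours that the write tests above exercise. This is a sketch, not the deleted implementation; documents are simplified to `(id, content)` pairs.

```python
# Toy illustration of the skip / overwrite / fail semantics of DuplicatePolicy.
from enum import Enum

class DuplicatePolicy(Enum):
    SKIP = "skip"
    OVERWRITE = "overwrite"
    FAIL = "fail"

class ToyStore:
    def __init__(self):
        self._docs = {}

    def write_documents(self, documents, policy=DuplicatePolicy.FAIL) -> int:
        written = 0
        for doc_id, content in documents:
            if doc_id in self._docs:
                if policy is DuplicatePolicy.FAIL:
                    raise ValueError(f"duplicate document id: {doc_id}")
                if policy is DuplicatePolicy.SKIP:
                    continue
            self._docs[doc_id] = content
            written += 1
        return written

store = ToyStore()
assert store.write_documents([("1", "first")]) == 1
assert store.write_documents([("1", "again")], policy=DuplicatePolicy.SKIP) == 0
assert store.write_documents([("1", "again")], policy=DuplicatePolicy.OVERWRITE) == 1
```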
- """ - doc = Document(content="test doc") - document_store.write_documents([doc]) - assert document_store.count_documents() == 1 - - document_store.delete_documents([doc.id]) - assert document_store.count_documents() == 0 - - @pytest.mark.unit - def test_delete_documents_empty_document_store(self, document_store: DocumentStore): - """ - Test delete_documents() doesn't fail when called using an empty Document Store. - """ - document_store.delete_documents(["non_existing_id"]) - - @pytest.mark.unit - def test_delete_documents_non_existing_document(self, document_store: DocumentStore): - """ - Test delete_documents() doesn't delete any Document when called with non existing id. - """ - doc = Document(content="test doc") - document_store.write_documents([doc]) - assert document_store.count_documents() == 1 - - document_store.delete_documents(["non_existing_id"]) - - # No Document has been deleted - assert document_store.count_documents() == 1 - - -class FilterableDocsFixtureMixin: - """ - Mixin class that adds a filterable_docs() fixture to a test class. - """ - - @pytest.fixture - def filterable_docs(self) -> List[Document]: - documents = [] - for i in range(3): - documents.append( - Document( - content=f"A Foo Document {i}", - meta={"name": f"name_{i}", "page": "100", "chapter": "intro", "number": 2}, - embedding=_random_embeddings(768), - ) - ) - documents.append( - Document( - content=f"A Bar Document {i}", - meta={"name": f"name_{i}", "page": "123", "chapter": "abstract", "number": -2}, - embedding=_random_embeddings(768), - ) - ) - documents.append( - Document( - content=f"A Foobar Document {i}", - meta={"name": f"name_{i}", "page": "90", "chapter": "conclusion", "number": -10}, - embedding=_random_embeddings(768), - ) - ) - documents.append( - Document( - content=f"Document {i} without embedding", - meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, - ) - ) - documents.append(Document(dataframe=pd.DataFrame([i]), meta={"name": f"table_doc_{i}"})) - documents.append( - Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) - ) - documents.append( - Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) - ) - return documents - - -class LegacyFilterDocumentsInvalidFiltersTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using invalid legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsInvalidFiltersTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_incorrect_filter_type(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(ValueError): - document_store.filter_documents(filters="something odd") # type: ignore - - @pytest.mark.unit - def test_incorrect_filter_nesting(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"number": {"page": "100"}}) - - @pytest.mark.unit - def test_deeper_incorrect_filter_nesting(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"number": {"page": {"chapter": "intro"}}}) - - -class LegacyFilterDocumentsEqualTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using implicit and explicit '$eq' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsEqualTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_filter_document_content(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"content": "A Foo Document 1"}) - assert result == [doc for doc in filterable_docs if doc.content == "A Foo Document 1"] - - @pytest.mark.unit - def test_filter_simple_metadata_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": "100"}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - - @pytest.mark.unit - def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - assert result == [ - doc for doc in filterable_docs if doc.dataframe is not None and doc.dataframe.equals(pd.DataFrame([1])) - ] - - @pytest.mark.unit - def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": {"$eq": "100"}}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - - @pytest.mark.unit - def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": "100"}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - - @pytest.mark.unit - def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - assert 
result == [ - doc - for doc in filterable_docs - if isinstance(doc.dataframe, pd.DataFrame) and doc.dataframe.equals(pd.DataFrame([1])) - ] - - @pytest.mark.unit - def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - embedding = [0.0] * 768 - result = document_store.filter_documents(filters={"embedding": embedding}) - assert result == [doc for doc in filterable_docs if embedding == doc.embedding] - - -class LegacyFilterDocumentsNotEqualTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using explicit '$ne' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsNotEqualTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": {"$ne": "100"}}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") != "100"] - - @pytest.mark.unit - def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"dataframe": {"$ne": pd.DataFrame([1])}}) - assert result == [ - doc - for doc in filterable_docs - if not isinstance(doc.dataframe, pd.DataFrame) or not doc.dataframe.equals(pd.DataFrame([1])) - ] - - @pytest.mark.unit - def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"embedding": {"$ne": TEST_EMBEDDING_1}}) - assert result == [doc for doc in filterable_docs if doc.embedding != TEST_EMBEDDING_1] - - -class LegacyFilterDocumentsInTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using implicit and explicit '$in' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsInTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_filter_simple_list_single_element(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": ["100"]}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") == "100"] - - @pytest.mark.unit - def test_filter_simple_list_one_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": ["100"]}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]] - - @pytest.mark.unit - def test_filter_simple_list(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": ["100", "123"]}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - - @pytest.mark.unit - def test_incorrect_filter_name(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"non_existing_meta_field": ["whatever"]}) - assert len(result) == 0 - - @pytest.mark.unit - def test_incorrect_filter_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": ["nope"]}) - assert len(result) == 0 - - @pytest.mark.unit - def test_in_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - - @pytest.mark.unit - def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": ["100", "123", "n.a."]}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - - @pytest.mark.unit - def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"dataframe": {"$in": [pd.DataFrame([1]), pd.DataFrame([2])]}}) - assert result == [ - doc - for doc in filterable_docs - if isinstance(doc.dataframe, pd.DataFrame) - and (doc.dataframe.equals(pd.DataFrame([1])) or doc.dataframe.equals(pd.DataFrame([2]))) - ] - - @pytest.mark.unit - def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - embedding_zero = [0.0] * 768 - embedding_one = [1.0] * 768 - result = document_store.filter_documents(filters={"embedding": {"$in": [embedding_zero, embedding_one]}}) - assert result == [ - doc for doc in filterable_docs if (embedding_zero == doc.embedding or embedding_one == doc.embedding) - ] - - -class LegacyFilterDocumentsNotInTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store 
`filter_documents` method using explicit '$nin' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsNotInTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"dataframe": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}} - ) - assert result == [ - doc - for doc in filterable_docs - if not isinstance(doc.dataframe, pd.DataFrame) - or (not doc.dataframe.equals(pd.DataFrame([1])) and not doc.dataframe.equals(pd.DataFrame([0]))) - ] - - @pytest.mark.unit - def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"embedding": {"$nin": [TEST_EMBEDDING_1, TEST_EMBEDDING_2]}}) - assert result == [doc for doc in filterable_docs if doc.embedding not in [TEST_EMBEDDING_1, TEST_EMBEDDING_2]] - - @pytest.mark.unit - def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) - assert result == [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]] - - -class LegacyFilterDocumentsGreaterThanTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using explicit '$gt' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsGreaterThanTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_gt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$gt": 0.0}}) - assert result == [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0] - - @pytest.mark.unit - def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"page": {"$gt": "100"}}) - - @pytest.mark.unit - def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"dataframe": {"$gt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) - - @pytest.mark.unit - def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"embedding": {"$gt": TEST_EMBEDDING_1}}) - - -class LegacyFilterDocumentsGreaterThanEqualTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using explicit '$gte' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsGreaterThanEqualTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_gte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$gte": -2}}) - assert result == [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2] - - @pytest.mark.unit - def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"page": {"$gte": "100"}}) - - @pytest.mark.unit - def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"dataframe": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) - - @pytest.mark.unit - def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"embedding": {"$gte": TEST_EMBEDDING_1}}) - - -class LegacyFilterDocumentsLessThanTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using explicit '$lt' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsLessThanTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_lt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$lt": 0.0}}) - assert result == [ - doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] < 0 - ] - - @pytest.mark.unit - def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"page": {"$lt": "100"}}) - - @pytest.mark.unit - def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"dataframe": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) - - @pytest.mark.unit - def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"embedding": {"$lt": TEST_EMBEDDING_2}}) - - -class LegacyFilterDocumentsLessThanEqualTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using explicit '$lte' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsLessThanEqualTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_lte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$lte": 2.0}}) - assert result == [ - doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0 - ] - - @pytest.mark.unit - def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"page": {"$lte": "100"}}) - - @pytest.mark.unit - def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"dataframe": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) - - @pytest.mark.unit - def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"embedding": {"$lte": TEST_EMBEDDING_1}}) - - -class LegacyFilterDocumentsSimpleLogicalTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using logical '$and', '$or' and '$not' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsSimpleLogicalTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_filter_simple_or(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} - result = document_store.filter_documents(filters=filters) - assert result == [ - doc - for doc in filterable_docs - if (doc.meta.get("number") is not None and doc.meta["number"] < 1) - or doc.meta.get("name") in ["name_0", "name_1"] - ] - - @pytest.mark.unit - def test_filter_simple_implicit_and_with_multi_key_dict( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}}) - assert result == [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 - ] - - @pytest.mark.unit - def test_filter_simple_explicit_and_with_list(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) - assert result == [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ] - - @pytest.mark.unit - def test_filter_simple_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}}) - assert result == [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ] - - -class LegacyFilterDocumentsNestedLogicalTest(FilterableDocsFixtureMixin): - """ - Utility class to test a Document Store `filter_documents` method using multiple nested logical '$and', '$or' and '$not' legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
- Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsNestedLogicalTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_filter_nested_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]} - result = document_store.filter_documents(filters=filters_simplified) - assert result == [ - doc - for doc in filterable_docs - if ( - "number" in doc.meta - and doc.meta["number"] <= 2 - and doc.meta["number"] >= 0 - and doc.meta.get("name") in ["name_0", "name_1"] - ) - ] - - @pytest.mark.unit - def test_filter_nested_or(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}} - result = document_store.filter_documents(filters=filters) - assert result == [ - doc - for doc in filterable_docs - if ( - doc.meta.get("name") in ["name_0", "name_1"] - or (doc.meta.get("number") is not None and doc.meta["number"] < 1) - ) - ] - - @pytest.mark.unit - def test_filter_nested_and_or_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} - } - result = document_store.filter_documents(filters=filters_simplified) - assert result == [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) - ) - ] - - @pytest.mark.unit - def test_filter_nested_and_or_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "page": {"$eq": "123"}, - "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}, - } - result = document_store.filter_documents(filters=filters_simplified) - assert result == [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) - ) - ] - - @pytest.mark.unit - def test_filter_nested_or_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "$or": { - "number": {"$lt": 1}, - "$and": {"name": {"$in": ["name_0", "name_1"]}, "$not": {"chapter": {"$eq": "intro"}}}, - } - } - result = document_store.filter_documents(filters=filters_simplified) - assert result == [ - doc - for doc in filterable_docs - if ( - (doc.meta.get("number") is not None and doc.meta["number"] < 1) - or (doc.meta.get("name") in ["name_0", "name_1"] and (doc.meta.get("chapter") != "intro")) - ) - ] - - @pytest.mark.unit - def test_filter_nested_multiple_identical_operators_same_level( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - document_store.write_documents(filterable_docs) - filters = { - "$or": [ - {"$and": {"name": {"$in": ["name_0", "name_1"]}, "page": "100"}}, - {"$and": {"chapter": {"$in": ["intro", "abstract"]}, "page": "123"}}, - ] - } - result = document_store.filter_documents(filters=filters) - assert result == [ - 
doc - for doc in filterable_docs - if ( - (doc.meta.get("name") in ["name_0", "name_1"] and doc.meta.get("page") == "100") - or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") - ) - ] - - -class LegacyFilterDocumentsTest( # pylint: disable=too-many-ancestors - LegacyFilterDocumentsInvalidFiltersTest, - LegacyFilterDocumentsEqualTest, - LegacyFilterDocumentsNotEqualTest, - LegacyFilterDocumentsInTest, - LegacyFilterDocumentsNotInTest, - LegacyFilterDocumentsGreaterThanTest, - LegacyFilterDocumentsGreaterThanEqualTest, - LegacyFilterDocumentsLessThanTest, - LegacyFilterDocumentsLessThanEqualTest, - LegacyFilterDocumentsSimpleLogicalTest, - LegacyFilterDocumentsNestedLogicalTest, -): - """ - Utility class to test a Document Store `filter_documents` method using different types of legacy filters - - To use it create a custom test class and override the `document_store` fixture to return your Document Store. - Example usage: - - ```python - class MyDocumentStoreTest(LegacyFilterDocumentsTest): - @pytest.fixture - def document_store(self): - return MyDocumentStore() - ``` - """ - - @pytest.mark.unit - def test_no_filter_empty(self, document_store: DocumentStore): - assert document_store.filter_documents() == [] - assert document_store.filter_documents(filters={}) == [] - - @pytest.mark.unit - def test_no_filter_not_empty(self, document_store: DocumentStore): - docs = [Document(content="test doc")] - document_store.write_documents(docs) - assert document_store.filter_documents() == docs - assert document_store.filter_documents(filters={}) == docs - - -class DocumentStoreBaseTests( - CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, LegacyFilterDocumentsTest -): # pylint: disable=too-many-ancestors - @pytest.fixture - def document_store(self) -> DocumentStore: - raise NotImplementedError() diff --git a/haystack/preview/testing/factory.py b/haystack/preview/testing/factory.py deleted file mode 100644 index d36392bfa0..0000000000 --- a/haystack/preview/testing/factory.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import Any, Dict, Optional, Tuple, Type, List, Union - -from haystack.preview import default_to_dict, default_from_dict -from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import document_store, DocumentStore, DuplicatePolicy - - -def document_store_class( - name: str, - documents: Optional[List[Document]] = None, - documents_count: Optional[int] = None, - bases: Optional[Tuple[type, ...]] = None, - extra_fields: Optional[Dict[str, Any]] = None, -) -> Type[DocumentStore]: - """ - Utility function to create a DocumentStore class with the given name and list of documents. - - If `documents` is set but `documents_count` is not, `documents_count` will be the length - of `documents`. - If both are set explicitly they don't influence each other. - - `write_documents()` and `delete_documents()` are no-op. - You can override them using `extra_fields`. 
- - ### Usage - - Create a DocumentStore class that returns no documents: - ```python - MyFakeStore = document_store_class("MyFakeComponent") - document_store = MyFakeStore() - assert document_store.documents_count() == 0 - assert document_store.filter_documents() == [] - ``` - - Create a DocumentStore class that returns a single document: - ```python - doc = Document(id="fake_id", text="Fake content") - MyFakeStore = document_store_class("MyFakeComponent", documents=[doc]) - document_store = MyFakeStore() - assert document_store.documents_count() == 1 - assert document_store.filter_documents() == [doc] - ``` - - Create a DocumentStore class that returns no document but returns a custom count: - ```python - MyFakeStore = document_store_class("MyFakeComponent", documents_count=100) - document_store = MyFakeStore() - assert document_store.documents_count() == 100 - assert document_store.filter_documents() == [] - ``` - - Create a DocumentStore class that returns a document and a custom count: - ```python - doc = Document(id="fake_id", text="Fake content") - MyFakeStore = document_store_class("MyFakeComponent", documents=[doc], documents_count=100) - document_store = MyFakeStore() - assert document_store.documents_count() == 100 - assert document_store.filter_documents() == [doc] - ``` - - Create a DocumentStore class with a custom base class: - ```python - MyFakeStore = document_store_class( - "MyFakeStore", - bases=(MyBaseClass,) - ) - document_store = MyFakeStore() - assert isinstance(store, MyBaseClass) - ``` - - Create a DocumentStore class with an extra field `my_field`: - ```python - MyFakeStore = document_store_class( - "MyFakeStore", - extra_fields={"my_field": 10} - ) - document_store = MyFakeStore() - assert document_store.my_field == 10 - ``` - """ - if documents is not None and documents_count is None: - documents_count = len(documents) - elif documents_count is None: - documents_count = 0 - - def count_documents(self) -> Union[int, None]: - return documents_count - - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - if documents is not None: - return documents - return [] - - def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: - return - - def delete_documents(self, document_ids: List[str]) -> None: - return - - def to_dict(self) -> Dict[str, Any]: - return default_to_dict(self) - - fields = { - "count_documents": count_documents, - "filter_documents": filter_documents, - "write_documents": write_documents, - "delete_documents": delete_documents, - "to_dict": to_dict, - "from_dict": classmethod(default_from_dict), - } - - if extra_fields is not None: - fields = {**fields, **extra_fields} - - if bases is None: - bases = (object,) - - cls = type(name, bases, fields) - return document_store(cls) diff --git a/haystack/preview/testing/test_utils.py b/haystack/preview/testing/test_utils.py deleted file mode 100644 index 596feb7001..0000000000 --- a/haystack/preview/testing/test_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import random -import logging -import numpy as np - - -logger = logging.getLogger(__name__) - - -def set_all_seeds(seed: int, deterministic_cudnn: bool = False) -> None: - """ - Setting multiple seeds to make runs reproducible. - - Important: Enabling `deterministic_cudnn` gives you full reproducibility with CUDA, - but might slow down your training (see https://pytorch.org/docs/stable/notes/randomness.html#cudnn) ! 
- - :param seed:number to use as seed - :param deterministic_cudnn: Enable for full reproducibility when using CUDA. Caution: might slow down training. - """ - random.seed(seed) - np.random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - - try: - import torch - - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - if deterministic_cudnn: - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - except (ImportError, ModuleNotFoundError) as exc: - logger.info("Could not set PyTorch seed because torch is not installed. Exception: %s", exc) diff --git a/haystack/preview/utils/__init__.py b/haystack/preview/utils/__init__.py deleted file mode 100644 index a84ea468e2..0000000000 --- a/haystack/preview/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from haystack.preview.utils.expit import expit -from haystack.preview.utils.requests_utils import request_with_retry -from haystack.preview.utils.filters import document_matches_filter diff --git a/haystack/preview/utils/expit.py b/haystack/preview/utils/expit.py deleted file mode 100644 index 0aaaa563cc..0000000000 --- a/haystack/preview/utils/expit.py +++ /dev/null @@ -1,5 +0,0 @@ -import numpy as np - - -def expit(x: float) -> float: - return 1 / (1 + np.exp(-x)) diff --git a/haystack/preview/utils/filters.py b/haystack/preview/utils/filters.py deleted file mode 100644 index 35475c15db..0000000000 --- a/haystack/preview/utils/filters.py +++ /dev/null @@ -1,305 +0,0 @@ -from typing import List, Any, Union, Dict -from dataclasses import fields -from datetime import datetime - -import pandas as pd - -from haystack.preview.dataclasses import Document -from haystack.preview.errors import FilterError - - -def document_matches_filter(filters: Dict[str, Any], document: Document) -> bool: - """ - Return whether `filters` match the Document. - For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol documentation. 
- """ - if "field" in filters: - return _comparison_condition(filters, document) - return _logic_condition(filters, document) - - -def _and(document: Document, conditions: List[Dict[str, Any]]) -> bool: - return all(_comparison_condition(condition, document) for condition in conditions) - - -def _or(document: Document, conditions: List[Dict[str, Any]]) -> bool: - return any(_comparison_condition(condition, document) for condition in conditions) - - -def _not(document: Document, conditions: List[Dict[str, Any]]) -> bool: - return not _and(document, conditions) - - -LOGICAL_OPERATORS = {"NOT": _not, "OR": _or, "AND": _and} - - -def _equal(document_value: Any, filter_value: Any) -> bool: - if isinstance(document_value, pd.DataFrame): - document_value = document_value.to_json() - - if isinstance(filter_value, pd.DataFrame): - filter_value = filter_value.to_json() - - return document_value == filter_value - - -def _not_equal(document_value: Any, filter_value: Any) -> bool: - return not _equal(document_value=document_value, filter_value=filter_value) - - -def _greater_than(document_value: Any, filter_value: Any) -> bool: - if document_value is None or filter_value is None: - # We can't compare None values reliably using operators '>', '>=', '<', '<=' - return False - - if isinstance(document_value, str) or isinstance(filter_value, str): - try: - document_value = datetime.fromisoformat(document_value) - filter_value = datetime.fromisoformat(filter_value) - except (ValueError, TypeError) as exc: - msg = ( - "Can't compare strings using operators '>', '>=', '<', '<='. " - "Strings are only comparable if they are ISO formatted dates." - ) - raise FilterError(msg) from exc - if type(filter_value) in [list, pd.DataFrame]: - msg = f"Filter value can't be of type {type(filter_value)} using operators '>', '>=', '<', '<='" - raise FilterError(msg) - return document_value > filter_value - - -def _greater_than_equal(document_value: Any, filter_value: Any) -> bool: - if document_value is None or filter_value is None: - # We can't compare None values reliably using operators '>', '>=', '<', '<=' - return False - - return _equal(document_value=document_value, filter_value=filter_value) or _greater_than( - document_value=document_value, filter_value=filter_value - ) - - -def _less_than(document_value: Any, filter_value: Any) -> bool: - if document_value is None or filter_value is None: - # We can't compare None values reliably using operators '>', '>=', '<', '<=' - return False - - return not _greater_than_equal(document_value=document_value, filter_value=filter_value) - - -def _less_than_equal(document_value: Any, filter_value: Any) -> bool: - if document_value is None or filter_value is None: - # We can't compare None values reliably using operators '>', '>=', '<', '<=' - return False - - return not _greater_than(document_value=document_value, filter_value=filter_value) - - -def _in(document_value: Any, filter_value: Any) -> bool: - if not isinstance(filter_value, list): - msg = ( - f"Filter value must be a `list` when using operator 'in' or 'not in', received type '{type(filter_value)}'" - ) - raise FilterError(msg) - return any(_equal(e, document_value) for e in filter_value) - - -def _not_in(document_value: Any, filter_value: Any) -> bool: - return not _in(document_value=document_value, filter_value=filter_value) - - -COMPARISON_OPERATORS = { - "==": _equal, - "!=": _not_equal, - ">": _greater_than, - ">=": _greater_than_equal, - "<": _less_than, - "<=": _less_than_equal, - "in": _in, - "not in": _not_in, 
-} - - -def _logic_condition(condition: Dict[str, Any], document: Document) -> bool: - if "operator" not in condition: - msg = f"'operator' key missing in {condition}" - raise FilterError(msg) - if "conditions" not in condition: - msg = f"'conditions' key missing in {condition}" - raise FilterError(msg) - operator: str = condition["operator"] - conditions: List[Dict[str, Any]] = condition["conditions"] - return LOGICAL_OPERATORS[operator](document, conditions) - - -def _comparison_condition(condition: Dict[str, Any], document: Document) -> bool: - if "field" not in condition: - # 'field' key is only found in comparison dictionaries. - # We assume this is a logic dictionary since it's not present. - return _logic_condition(condition, document) - field: str = condition["field"] - - if "operator" not in condition: - msg = f"'operator' key missing in {condition}" - raise FilterError(msg) - if "value" not in condition: - msg = f"'value' key missing in {condition}" - raise FilterError(msg) - - if "." in field: - # Handles fields formatted like so: - # 'meta.person.name' - parts = field.split(".") - document_value = getattr(document, parts[0]) - for part in parts[1:]: - if part not in document_value: - # If a field is not found we treat it as None - document_value = None - break - document_value = document_value[part] - elif field not in [f.name for f in fields(document)]: - # Converted legacy filters don't add the `meta.` prefix, so we assume - # that all filter fields that are not actual fields in Document are converted - # filters. - # - # We handle this to avoid breaking compatibility with converted legacy filters. - # This will be removed as soon as we stop supporting legacy filters. - document_value = document.meta.get(field) - else: - document_value = getattr(document, field) - operator: str = condition["operator"] - filter_value: Any = condition["value"] - return COMPARISON_OPERATORS[operator](filter_value=filter_value, document_value=document_value) - - -def convert(filters: Dict[str, Any]) -> Dict[str, Any]: - """ - Convert a filter declared using the legacy style into the new style. - This is mostly meant to ease migration from Haystack 1.x to 2.x for developers - of Document Stores and Components that use filters. - - This function doesn't verify if `filters` are declared using the legacy style. 
- - Example usage: - ```python - legacy_filter = { - "$and": { - "type": {"$eq": "article"}, - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": {"genre": {"$in": ["economy", "politics"]}, "publisher": {"$eq": "nytimes"}}, - } - } - assert convert(legacy_filter) == { - "operator": "AND", - "conditions": [ - {"field": "type", "operator": "==", "value": "article"}, - {"field": "date", "operator": ">=", "value": "2015-01-01"}, - {"field": "date", "operator": "<", "value": "2021-01-01"}, - {"field": "rating", "operator": ">=", "value": 3}, - { - "operator": "OR", - "conditions": [ - {"field": "genre", "operator": "in", "value": ["economy", "politics"]}, - {"field": "publisher", "operator": "==", "value": "nytimes"}, - ], - }, - ], - } - ``` - """ - if not isinstance(filters, dict): - msg = f"Can't convert filters from type '{type(filters)}'" - raise ValueError(msg) - - converted = _internal_convert(filters) - if "conditions" not in converted: - # This is done to handle a corner case when filter is really simple like so: - # {"text": "A Foo Document 1"} - # The root '$and' operator is implicit so the conversion doesn't handle - # it and it must be added explicitly like so. - # This only happens for simple filters like the one above. - return {"operator": "AND", "conditions": [converted]} - return converted - - -def _internal_convert(filters: Union[List[Any], Dict[str, Any]], previous_key=None) -> Any: - """ - Recursively convert filters from legacy to new style. - """ - conditions = [] - - if isinstance(filters, list) and (result := _handle_list(filters, previous_key)) is not None: - return result - - if not isinstance(filters, dict): - return _handle_non_dict(filters, previous_key) - - for key, value in filters.items(): - if ( - previous_key is not None - and previous_key not in ALL_LEGACY_OPERATORS_MAPPING - and key not in ALL_LEGACY_OPERATORS_MAPPING - ): - msg = f"This filter ({filters}) seems to be malformed." 
- raise FilterError(msg) - if key not in ALL_LEGACY_OPERATORS_MAPPING: - converted = _internal_convert(value, previous_key=key) - if isinstance(converted, list): - conditions.extend(converted) - else: - conditions.append(converted) - elif key in LEGACY_LOGICAL_OPERATORS_MAPPING: - if previous_key not in ALL_LEGACY_OPERATORS_MAPPING and isinstance(value, list): - converted = [_internal_convert({previous_key: v}) for v in value] - conditions.append({"operator": ALL_LEGACY_OPERATORS_MAPPING[key], "conditions": converted}) - else: - converted = _internal_convert(value, previous_key=key) - if key == "$not" and type(converted) not in [dict, list]: - # This handles a corner when '$not' is used like this: - # '{"page": {"$not": 102}}' - # Without this check we would miss the implicit '$eq' - converted = {"field": previous_key, "operator": "==", "value": value} - if not isinstance(converted, list): - converted = [converted] - conditions.append({"operator": ALL_LEGACY_OPERATORS_MAPPING[key], "conditions": converted}) - elif key in LEGACY_COMPARISON_OPERATORS_MAPPING: - conditions.append({"field": previous_key, "operator": ALL_LEGACY_OPERATORS_MAPPING[key], "value": value}) - - if len(conditions) == 1: - return conditions[0] - - if previous_key is None: - return {"operator": "AND", "conditions": conditions} - - return conditions - - -def _handle_list(filters, previous_key): - if previous_key in LEGACY_LOGICAL_OPERATORS_MAPPING: - return [_internal_convert(f) for f in filters] - elif previous_key not in LEGACY_COMPARISON_OPERATORS_MAPPING: - return {"field": previous_key, "operator": "in", "value": filters} - return None - - -def _handle_non_dict(filters, previous_key): - if previous_key not in ALL_LEGACY_OPERATORS_MAPPING: - return {"field": previous_key, "operator": "==", "value": filters} - return filters - - -# Operator mappings from legacy style to new one -LEGACY_LOGICAL_OPERATORS_MAPPING = {"$and": "AND", "$or": "OR", "$not": "NOT"} - -LEGACY_COMPARISON_OPERATORS_MAPPING = { - "$eq": "==", - "$ne": "!=", - "$gt": ">", - "$gte": ">=", - "$lt": "<", - "$lte": "<=", - "$in": "in", - "$nin": "not in", -} - -ALL_LEGACY_OPERATORS_MAPPING = {**LEGACY_LOGICAL_OPERATORS_MAPPING, **LEGACY_COMPARISON_OPERATORS_MAPPING} diff --git a/haystack/preview/utils/requests_utils.py b/haystack/preview/utils/requests_utils.py deleted file mode 100644 index 245d7737fb..0000000000 --- a/haystack/preview/utils/requests_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Optional, List - -import logging - -from tenacity import retry, wait_exponential, retry_if_exception_type, stop_after_attempt, before_log, after_log -import requests - -logger = logging.getLogger(__file__) - - -def request_with_retry( - attempts: int = 3, status_codes_to_retry: Optional[List[int]] = None, **kwargs -) -> requests.Response: - """ - request_with_retry is a simple wrapper function that executes an HTTP request - with a configurable exponential backoff retry on failures. - - All kwargs will be passed to ``requests.request``, so it accepts the same arguments. 
- - Example Usage: - -------------- - - # Sending an HTTP request with default retry configs - res = request_with_retry(method="GET", url="https://example.com") - - # Sending an HTTP request with custom number of attempts - res = request_with_retry(method="GET", url="https://example.com", attempts=10) - - # Sending an HTTP request with custom HTTP codes to retry - res = request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[408, 503]) - - # Sending an HTTP request with custom timeout in seconds - res = request_with_retry(method="GET", url="https://example.com", timeout=5) - - # Sending an HTTP request with custom authorization handling - class CustomAuth(requests.auth.AuthBase): - def __call__(self, r): - r.headers["authorization"] = "Basic " - return r - - res = request_with_retry(method="GET", url="https://example.com", auth=CustomAuth()) - - # All of the above combined - res = request_with_retry( - method="GET", - url="https://example.com", - auth=CustomAuth(), - attempts=10, - status_codes_to_retry=[408, 503], - timeout=5 - ) - - # Sending a POST request - res = request_with_retry(method="POST", url="https://example.com", data={"key": "value"}, attempts=10) - - # Retry all 5xx status codes - res = request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=list(range(500, 600))) - - :param attempts: Maximum number of attempts to retry the request, defaults to 3 - :param status_codes_to_retry: List of HTTP status codes that will trigger a retry, defaults to [408, 418, 429, 503]: - - `408: Request Timeout` - - `418` - - `429: Too Many Requests` - - `503: Service Unavailable` - :param **kwargs: Optional arguments that ``request`` takes. - :return: :class:`Response ` object - """ - - if status_codes_to_retry is None: - status_codes_to_retry = [408, 418, 429, 503] - - @retry( - reraise=True, - wait=wait_exponential(), - retry=retry_if_exception_type((requests.HTTPError, TimeoutError)), - stop=stop_after_attempt(attempts), - before=before_log(logger, logging.DEBUG), - after=after_log(logger, logging.DEBUG), - ) - def run(): - timeout = kwargs.pop("timeout", 10) - res = requests.request(**kwargs, timeout=timeout) - - if res.status_code in status_codes_to_retry: - # We raise only for the status codes that must trigger a retry - res.raise_for_status() - - return res - - res = run() - # We raise here too in case the request failed with a status code that - # won't trigger a retry, this way the call will still cause an explicit exception - res.raise_for_status() - return res diff --git a/haystack/preview/version.py b/haystack/preview/version.py deleted file mode 100644 index 23a3060671..0000000000 --- a/haystack/preview/version.py +++ /dev/null @@ -1,13 +0,0 @@ -from importlib import metadata - -# haystack.preview is distributed as a separate package called `haystack-ai`. -# We want to keep all preview dependencies separate from the current Haystack version, -# so imports in haystack.preview must only import from haystack.preview. -# Since we need to access __version__ in haystack.preview without importing from -# haystack we must set it here too. -# When installing `haystack-ai` we want to use that package version though -# as `farm-haystack` might not be installed and cause this to fail. 
-try: - __version__ = str(metadata.version("haystack-ai")) -except metadata.PackageNotFoundError: - __version__ = str(metadata.version("farm-haystack")) diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 15a9794082..6a27f4b0f2 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -24,4 +24,3 @@ from haystack.utils.early_stopping import EarlyStopping from haystack.utils.labels import aggregate_labels from haystack.utils.batching import get_batches_from_generator -from haystack.utils.getting_started import build_pipeline, add_example_data diff --git a/haystack/utils/experiment_tracking.py b/haystack/utils/experiment_tracking.py index 2a9f8d1ef4..21195449d7 100644 --- a/haystack/utils/experiment_tracking.py +++ b/haystack/utils/experiment_tracking.py @@ -213,7 +213,7 @@ def track_params(self, params: Dict[str, Any]): def track_artifacts(self, dir_path: Union[str, Path], artifact_path: Optional[str] = None): try: - mlflow.log_artifacts(dir_path, artifact_path) + mlflow.log_artifacts(str(dir_path), artifact_path) except ConnectionError: logger.warning("ConnectionError in logging artifacts to MLflow") except Exception as e: diff --git a/haystack/utils/getting_started.py b/haystack/utils/getting_started.py deleted file mode 100644 index cd54e7169d..0000000000 --- a/haystack/utils/getting_started.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import os - -from haystack.utils import convert_files_to_docs -from haystack.utils import fetch_archive_from_http - -logger = logging.getLogger(__name__) - - -def build_pipeline(provider, API_KEY, document_store): - # Importing top-level causes a circular import - from haystack.nodes import AnswerParser, PromptNode, PromptTemplate, BM25Retriever - from haystack.pipelines import Pipeline - - provider = provider.lower() - # A retriever selects the right documents when given a question. - retriever = BM25Retriever(document_store=document_store, top_k=5) - # Load prompt for doing retrieval augmented generation from https://prompthub.deepset.ai/?prompt=deepset%2Fquestion-answering-with-references - question_answering_with_references = PromptTemplate( - prompt="deepset/question-answering-with-references", - output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"), - ) - # Load the LLM model - if provider == "anthropic": - prompt_node = PromptNode( - model_name_or_path="claude-2", api_key=API_KEY, default_prompt_template=question_answering_with_references - ) - elif provider == "cohere": - prompt_node = PromptNode( - model_name_or_path="command", api_key=API_KEY, default_prompt_template=question_answering_with_references - ) - elif provider == "huggingface": - # TODO: swap out for meta-llama/Llama-2-7b-chat-hf or the 40b model once supported in Haystack+HF API free tier - # The tiiuae/falcon-7b-instruct model cannot handle a complex prompt with references, so we use a very simple one - simple_QA = PromptTemplate( - prompt="deepset/question-answering", output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]") - ) - prompt_node = PromptNode( - model_name_or_path="tiiuae/falcon-7b-instruct", api_key=API_KEY, default_prompt_template=simple_QA - ) - elif provider == "openai": - prompt_node = PromptNode( - model_name_or_path="gpt-3.5-turbo-0301", - api_key=API_KEY, - default_prompt_template=question_answering_with_references, - ) - else: - logger.error('Given unknown. 
Please use any of "anthropic", "cohere", "huggingface", or "openai"') - # Compose the query pipeline - query_pipeline = Pipeline() - query_pipeline.add_node(component=retriever, name="retriever", inputs=["Query"]) - query_pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["retriever"]) - - return query_pipeline - - -def add_example_data(document_store, dir): - # Importing top-level causes a circular import - from haystack.nodes import TextConverter, PreProcessor - - if dir == "data/GoT_getting_started": - # Download and add Game of Thrones TXT files - fetch_archive_from_http( - url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip", - output_dir=dir, - ) - files_to_index = [dir + "/" + f for f in os.listdir(dir)] - converter = TextConverter(remove_numeric_tables=True) - docs = [converter.convert(file_path=file, meta=None)[0] for file in files_to_index] - else: - # Here you can add a local folder with your files(.txt, .pdf, .docx). - # You might need to install additional packages with "pip install farm-haystack[ocr,preprocessing,file-conversion,pdf]". - # For more details, see: https://haystack.deepset.ai/tutorials/08_preprocessing. - # Be aware that some of your data will be sent to external APIs if you use this functionality! - files_to_index = [dir + "/" + f for f in os.listdir(dir)] - logger.info("Adding %s number of files from local disk at %s.", len(files_to_index), dir) - docs = convert_files_to_docs(dir_path=dir) - - preprocessor = PreProcessor( - split_by="word", split_length=200, split_overlap=0, split_respect_sentence_boundary=True - ) - docs_processed = preprocessor.process(docs) - - document_store.write_documents(documents=docs_processed) diff --git a/haystack/utils/import_utils.py b/haystack/utils/import_utils.py index a1d537ba50..6945fbf3ea 100644 --- a/haystack/utils/import_utils.py +++ b/haystack/utils/import_utils.py @@ -1,21 +1,19 @@ -import io import gzip -import tarfile -import zipfile -import logging import importlib import importlib.util +import io +import logging +import zipfile +from os.path import basename, splitext from pathlib import Path -from typing import Optional, Dict, Union, Tuple, List -from urllib.parse import urlparse, unquote -from os.path import splitext, basename +from typing import Dict, List, Optional, Tuple, Union +from urllib.parse import unquote, urlparse import requests from haystack.errors import DatasetsError from haystack.schema import Document - logger = logging.getLogger(__name__) @@ -69,8 +67,7 @@ def fetch_archive_from_http( timeout: Union[float, Tuple[float, float]] = 10.0, ) -> bool: """ - Fetch an archive (zip, gz or tar.gz) from a url via http and extract content to an output directory. - + Fetch an archive (zip or gz) from a url via http and extract content to an output directory. :param url: http address :param output_dir: local path :param proxies: proxies details as required by requests library @@ -102,9 +99,6 @@ def fetch_archive_from_http( file_content = gzip_archive.read() with open(f"{output_dir}/{file_name}", "wb") as file: file.write(file_content) - elif archive_extension in ["gz", "bz2", "xz"]: - tar_archive = tarfile.open(fileobj=io.BytesIO(request_data.content), mode="r|*") - tar_archive.extractall(output_dir) else: logger.warning( "Skipped url %s as file type is not supported here. 
" diff --git a/haystack/utils/openai_utils.py b/haystack/utils/openai_utils.py index 25de13e508..450b63a0ed 100644 --- a/haystack/utils/openai_utils.py +++ b/haystack/utils/openai_utils.py @@ -64,39 +64,37 @@ def _openai_text_completion_tokenization_details(model_name: str): :param model_name: Name of the OpenAI model. """ - tokenizer_name = "gpt2" - max_tokens_limit = 2049 # Based on this ref: https://platform.openai.com/docs/models/gpt-3 + tokenizer_name = "cl100k_base" + # It is the minimum max_tokens_limit value based on this ref: https://platform.openai.com/docs/models/overview + max_tokens_limit = 4096 try: model_tokenizer = tiktoken.encoding_name_for_model(model_name) except KeyError: model_tokenizer = None if model_tokenizer: - # Based on OpenAI models page, 'davinci' considers have 2049 tokens, - ## therefore, it is better to add `text-davinci` instead to the condition. - ## Ref: https://platform.openai.com/docs/models/gpt-3-5 - ## https://platform.openai.com/docs/models/gpt-3 - if "text-davinci" in model_name: - max_tokens_limit = 4097 - tokenizer_name = model_tokenizer - elif model_name.startswith("gpt-3.5-turbo-16k") or model_name.startswith("gpt-35-turbo-16k"): + tokenizer_name = model_tokenizer + if model_name == "davinci-002" or model_name == "babbage-002": max_tokens_limit = 16384 - tokenizer_name = model_tokenizer - elif model_name.startswith("gpt-3"): - max_tokens_limit = 4096 - tokenizer_name = model_tokenizer + + if model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-35-turbo"): + max_tokens_limit = 16385 + # Handles edge-cases where the value is 4096 + if ( + model_name == "gpt-3.5-turbo-instruct" + or model_name == "gpt-3.5-turbo-0613" + or model_name == "gpt-35-turbo-instruct" + or model_name == "gpt-35-turbo-0613" + ): + max_tokens_limit = 4096 + # Ref: https://platform.openai.com/docs/models/gpt-4 - elif model_name.startswith("gpt-4-32k"): - max_tokens_limit = 32768 # tokens - tokenizer_name = model_tokenizer - elif model_name.startswith("gpt-4-1106-preview"): - max_tokens_limit = 128000 # tokens - tokenizer_name = model_tokenizer - elif model_name.startswith("gpt-4"): - max_tokens_limit = 8192 # tokens - tokenizer_name = model_tokenizer - else: - tokenizer_name = model_tokenizer + if model_name.startswith("gpt-4"): + max_tokens_limit = 128000 + if model_name == "gpt-4" or model_name == "gpt-4-0613": + max_tokens_limit = 8192 + if model_name == "gpt-4-32k" or model_name == "gpt-4-32k-0613": + max_tokens_limit = 32768 return tokenizer_name, max_tokens_limit diff --git a/haystack/utils/torch_utils.py b/haystack/utils/torch_utils.py index 01aaba8023..ce0343466f 100644 --- a/haystack/utils/torch_utils.py +++ b/haystack/utils/torch_utils.py @@ -52,3 +52,25 @@ def get_devices(devices: Optional[List[Union[str, torch.device]]]) -> List[torch ): return [torch.device("mps")] return [torch.device("cpu")] + + +def resolve_torch_dtype(torch_dtype: Optional[Union[str, "torch.dtype"]]) -> Optional["torch.dtype"]: + """ + Extract the torch dtype specified in kwargs. This function ensures the returned dtype is of a `torch.dtype` type. + """ + torch_dtype_resolved = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if "torch." 
in torch_dtype: + torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch.")) + elif torch_dtype == "auto": + torch_dtype_resolved = torch_dtype + else: + raise ValueError( + f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}" + ) + elif isinstance(torch_dtype, torch.dtype): + torch_dtype_resolved = torch_dtype + else: + raise ValueError(f"Invalid torch_dtype value {torch_dtype}") + return torch_dtype_resolved diff --git a/pyproject.toml b/pyproject.toml index 8def0a45c3..136d9c8205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ dependencies = [ "requests", "httpx", "pydantic<2", - "transformers==4.35.2", + "transformers>=4.46,<5.0", "pandas", "rank_bm25", "scikit-learn>=1.3.0", # TF-IDF and metrics @@ -85,24 +85,9 @@ dependencies = [ ] [project.optional-dependencies] -preview = [ - "canals==0.10.1", - "requests", - "pandas", - "rank_bm25", - "tqdm", - "tenacity", - "lazy-imports", - "posthog", # telemetry - - "Jinja2", - "openai<1.0.0", - "pyyaml", - "more-itertools", # DocumentSplitter -] inference = [ - "transformers[torch,sentencepiece]==4.35.2", - "sentence-transformers>=2.2.0", # See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder + "transformers[torch,sentencepiece]>=4.46,<5.0", + "sentence-transformers<=3.0.0,>=2.3.1", # See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder "huggingface-hub>=0.5.0", ] elasticsearch = [ @@ -145,11 +130,14 @@ pinecone = [ opensearch = [ "opensearch-py>=2", ] +mongodb = [ + "pymongo>=4.6", +] docstores = [ - "farm-haystack[elasticsearch,faiss,weaviate,pinecone,opensearch]", + "farm-haystack[elasticsearch,faiss,weaviate,pinecone,opensearch,mongodb]", ] docstores-gpu = [ - "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch]", + "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch,mongodb]", ] aws = [ # first version to support Amazon Bedrock @@ -159,13 +147,13 @@ crawler = [ "selenium>=4.11.0" ] preprocessing = [ - "nltk", + "nltk>=3.9.1", "langdetect", # for language classification ] file-conversion = [ "azure-ai-formrecognizer>=3.2.0b2", # Microsoft Azure's Form Recognizer service (text and table exctrator) "python-docx", - "python-pptx", + "python-pptx<=1.0", "tika", # Apache Tika (text & metadata extractor) "beautifulsoup4", "markdown", @@ -173,9 +161,7 @@ file-conversion = [ "python-magic; platform_system != 'Windows'", # Depends on libmagic: https://pypi.org/project/python-magic/ "python-magic-bin; platform_system == 'Windows'", # Needs to be installed without python-magic, otherwise Windows CI gets stuck. 
] -pdf = [ - "PyMuPDF>=1.18.16" , # PDF text extraction alternative to xpdf; please check AGPLv3 license -] +pdf = [] ocr = [ "pytesseract>0.3.7", "pdf2image>1.14", @@ -205,7 +191,7 @@ colab = [ dev = [ "pre-commit", # Type check - "mypy", + "mypy==1.10.0", # Test "pytest", "pytest-cov", @@ -239,11 +225,11 @@ audio = [ ] all = [ - "farm-haystack[inference,docstores,crawler,preprocessing,file-conversion,pdf,ocr,metrics,aws,preview,audio]", + "farm-haystack[inference,docstores,crawler,preprocessing,file-conversion,pdf,ocr,metrics,aws,audio]", ] all-gpu = [ # beir is incompatible with faiss-gpu: https://github.com/beir-cellar/beir/issues/71 - "farm-haystack[inference,docstores-gpu,crawler,preprocessing,file-conversion,pdf,ocr,metrics,aws,preview,audio]", + "farm-haystack[inference,docstores-gpu,crawler,preprocessing,file-conversion,pdf,ocr,metrics,aws,audio]", ] [project.scripts] @@ -425,7 +411,6 @@ max-complexity = 28 [tool.ruff.per-file-ignores] "examples/basic_qa_pipeline.py" = ["C416"] -"haystack/preview/testing/document_store.py" = ["C416", "F821"] "haystack/telemetry.py" = ["F821"] [tool.ruff.pylint] diff --git a/releasenotes/config.yaml b/releasenotes/config.yaml index 4389c77c25..765b1b2d7c 100644 --- a/releasenotes/config.yaml +++ b/releasenotes/config.yaml @@ -1,7 +1,6 @@ default_branch: main collapse_pre_releases: true pre_release_tag_re: (?P-(?:[ab]|rc)+\d*)$ -prelude_section_name: ⭐ Highlights template: | --- prelude: > @@ -36,10 +35,7 @@ template: | fixes: - | Add normal bug fixes here, or remove this section. - preview: - - | - Add changes to Haystack version 2, or remove this section. - Haystack version 2 can be found under haystack/preview. + sections: # The prelude section is implicitly included. - [upgrade, ⬆️ Upgrade Notes] @@ -49,4 +45,3 @@ sections: - [deprecations, ⚠️ Deprecation Notes] - [security, Security Notes] - [fixes, 🐛 Bug Fixes] - - [preview, 🩵 Haystack 2.0 preview] diff --git a/releasenotes/notes/126-hihglights-f9464f4c40258d02.yaml b/releasenotes/notes/126-hihglights-f9464f4c40258d02.yaml new file mode 100644 index 0000000000..5006b3aed5 --- /dev/null +++ b/releasenotes/notes/126-hihglights-f9464f4c40258d02.yaml @@ -0,0 +1,6 @@ +--- +prelude: > + We are announcing that Haystack 1.26 is the final minor release for Haystack 1.x. + Although we will continue to release bug fixes for this version, we will neither + be adding nor removing any functionalities. Instead, we will focus our efforts on + Haystack 2.x. Haystack 1.26 will reach its end-of-life on March 11, 2025. diff --git a/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml b/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml new file mode 100644 index 0000000000..75629e23f7 --- /dev/null +++ b/releasenotes/notes/add-arg-to-PineconeDocumentStore-984add063663e70b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Users can now define the number of pods and pod type directly when creating a PineconeDocumentStore instance. diff --git a/releasenotes/notes/add-dead-end-for-classifier-8c87716695efd86a.yaml b/releasenotes/notes/add-dead-end-for-classifier-8c87716695efd86a.yaml new file mode 100644 index 0000000000..23a48174c2 --- /dev/null +++ b/releasenotes/notes/add-dead-end-for-classifier-8c87716695efd86a.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Pipeline run error when using the FileTypeClassifier with the raise_on_error: True option. + Instead of returning an unexpected NoneType, we route the file to a dead-end edge. 
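The `add-arg-to-PineconeDocumentStore` note above describes new `pods` and `pod_type` constructor arguments. A minimal sketch of creating a store with them follows; the concrete values, index name, and API key placeholder are illustrative assumptions, not part of this changeset:
```python
from haystack.document_stores import PineconeDocumentStore

document_store = PineconeDocumentStore(
    api_key="PINECONE_API_KEY",  # placeholder, not a real key
    embedding_dim=768,
    index="haystack_docs",       # example index name
    similarity="cosine",
    pods=2,                      # number of pods backing the index (assumed value)
    pod_type="p1.x1",            # Pinecone pod type (assumed value)
)
```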
diff --git a/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml b/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml new file mode 100644 index 0000000000..238aa5d23d --- /dev/null +++ b/releasenotes/notes/add-mongodb-document-store-34bd05d03717fb62.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add `MongoDBAtlasDocumentStore`, providing support for MongoDB Atlas as a document store. diff --git a/releasenotes/notes/add-raise-on-failure-to-base-converter-8c5e9b3dd51c0e0c.yaml b/releasenotes/notes/add-raise-on-failure-to-base-converter-8c5e9b3dd51c0e0c.yaml new file mode 100644 index 0000000000..05e4a959d7 --- /dev/null +++ b/releasenotes/notes/add-raise-on-failure-to-base-converter-8c5e9b3dd51c0e0c.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add `raise_on_failure` flag to BaseConverter class so that big processes can optionally continue without breaking from exceptions. diff --git a/releasenotes/notes/aws-bedrock-embedding-encoder-a978884c1a2c8237.yaml b/releasenotes/notes/aws-bedrock-embedding-encoder-a978884c1a2c8237.yaml new file mode 100644 index 0000000000..caa1f479b8 --- /dev/null +++ b/releasenotes/notes/aws-bedrock-embedding-encoder-a978884c1a2c8237.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adding Bedrock Embeddings Encoder to use as a retriever. diff --git a/releasenotes/notes/bedrock-support-for-llama3-a1c2c4fcfb5a8395.yaml b/releasenotes/notes/bedrock-support-for-llama3-a1c2c4fcfb5a8395.yaml new file mode 100644 index 0000000000..e9d7d36194 --- /dev/null +++ b/releasenotes/notes/bedrock-support-for-llama3-a1c2c4fcfb5a8395.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Support for Llama3 models on AWS Bedrock. diff --git a/releasenotes/notes/bedrock-support-for-mistralai-and-new-claude-3-models-f141aef0a7690ef3.yaml b/releasenotes/notes/bedrock-support-for-mistralai-and-new-claude-3-models-f141aef0a7690ef3.yaml new file mode 100644 index 0000000000..ca954843ba --- /dev/null +++ b/releasenotes/notes/bedrock-support-for-mistralai-and-new-claude-3-models-f141aef0a7690ef3.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Support for MistralAI and new Claude 3 models on AWS Bedrock. diff --git a/releasenotes/notes/bump-transformers-4-37-2-d265006022c21671.yaml b/releasenotes/notes/bump-transformers-4-37-2-d265006022c21671.yaml new file mode 100644 index 0000000000..d939c98ffa --- /dev/null +++ b/releasenotes/notes/bump-transformers-4-37-2-d265006022c21671.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Upgrade Transformers to the latest version 4.37.2. + This version adds support for the Phi-2 and Qwen2 models and improves support for quantization. diff --git a/releasenotes/notes/bump-transformers-to-4-39-2de45528eff50ce8.yaml b/releasenotes/notes/bump-transformers-to-4-39-2de45528eff50ce8.yaml new file mode 100644 index 0000000000..ec74b501d5 --- /dev/null +++ b/releasenotes/notes/bump-transformers-to-4-39-2de45528eff50ce8.yaml @@ -0,0 +1,3 @@ +--- +enhancements: + - Upgrade transformers to version 4.39.3 so that Haystack can support the new Cohere Command R models. diff --git a/releasenotes/notes/correct-crawler-pdf-download-location-82a443be7b07e182.yaml b/releasenotes/notes/correct-crawler-pdf-download-location-82a443be7b07e182.yaml new file mode 100644 index 0000000000..63a70fcfbc --- /dev/null +++ b/releasenotes/notes/correct-crawler-pdf-download-location-82a443be7b07e182.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Ensure that the crawled files are downloaded to the `output_dir` directory, as specified in the `Crawler` + constructor. 
Previously, some files were incorrectly downloaded to the current working directory. diff --git a/releasenotes/notes/crawler-webdriver-di-71225500f3751983.yaml b/releasenotes/notes/crawler-webdriver-di-71225500f3751983.yaml new file mode 100644 index 0000000000..c379875013 --- /dev/null +++ b/releasenotes/notes/crawler-webdriver-di-71225500f3751983.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Add an optional `webdriver` parameter to Crawler. + This allows using a pre-configured custom webdriver instead of creating + the default Chrome webdriver. diff --git a/releasenotes/notes/farmreader-answer-page-number-5b4a3c70e03b3580.yaml b/releasenotes/notes/farmreader-answer-page-number-5b4a3c70e03b3580.yaml new file mode 100644 index 0000000000..2b8c1e88ec --- /dev/null +++ b/releasenotes/notes/farmreader-answer-page-number-5b4a3c70e03b3580.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Correctly calculate the answer page number for Extractive Answers diff --git a/releasenotes/notes/farmreader-kwargs-130e0b1a89f75fef.yaml b/releasenotes/notes/farmreader-kwargs-130e0b1a89f75fef.yaml new file mode 100644 index 0000000000..664c327d27 --- /dev/null +++ b/releasenotes/notes/farmreader-kwargs-130e0b1a89f75fef.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add model_kwargs to FARMReader to allow loading in fp16 at inference time diff --git a/releasenotes/notes/feat-add-latest-openai-embeddings-models-759c575ebee93780.yaml b/releasenotes/notes/feat-add-latest-openai-embeddings-models-759c575ebee93780.yaml new file mode 100644 index 0000000000..9b7aaa82e5 --- /dev/null +++ b/releasenotes/notes/feat-add-latest-openai-embeddings-models-759c575ebee93780.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add support for latest OpenAI embedding models `text-embedding-3-large` and `text-embedding-3-small`. diff --git a/releasenotes/notes/fix-load-from-deepset-cloud-8a86053ccb246494.yaml b/releasenotes/notes/fix-load-from-deepset-cloud-8a86053ccb246494.yaml new file mode 100644 index 0000000000..98179516be --- /dev/null +++ b/releasenotes/notes/fix-load-from-deepset-cloud-8a86053ccb246494.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix `Pipeline.load_from_deepset_cloud` to work with the latest version of deepset Cloud. diff --git a/releasenotes/notes/fix-os-es-get-metadata-values-by-key-d5e34c79998c322d.yaml b/releasenotes/notes/fix-os-es-get-metadata-values-by-key-d5e34c79998c322d.yaml new file mode 100644 index 0000000000..c7f5b4a12b --- /dev/null +++ b/releasenotes/notes/fix-os-es-get-metadata-values-by-key-d5e34c79998c322d.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes `SearchEngineDocumentStore.get_metadata_values_by_key` method to make use of `self.index` if no index is provided. diff --git a/releasenotes/notes/fix-output-parser-in-prompt-template-22975013f44c435b.yaml b/releasenotes/notes/fix-output-parser-in-prompt-template-22975013f44c435b.yaml new file mode 100644 index 0000000000..b064c6cb69 --- /dev/null +++ b/releasenotes/notes/fix-output-parser-in-prompt-template-22975013f44c435b.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes OutputParser usage in PromptTemplate after making invocation context immutable in https://github.com/deepset-ai/haystack/pull/7510. 
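The `crawler-webdriver-di` note above introduces an optional `webdriver` parameter for `Crawler`. A hedged sketch of passing a pre-configured Selenium driver; the Chrome options and URLs shown are assumptions, not taken from this changeset:
```python
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

from haystack.nodes import Crawler

# Pre-configure a Chrome driver instead of letting Crawler create the default one.
chrome_options = Options()
chrome_options.add_argument("--headless=new")  # run Chrome without a UI (assumed setup)
custom_driver = webdriver.Chrome(options=chrome_options)

crawler = Crawler(
    urls=["https://haystack.deepset.ai"],  # example start URL
    output_dir="crawled_files",
    webdriver=custom_driver,               # new: reuse the pre-configured driver
)
```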
diff --git a/releasenotes/notes/fix-pipeline-with-join-node-5f23a426cd4d88d9.yaml b/releasenotes/notes/fix-pipeline-with-join-node-5f23a426cd4d88d9.yaml new file mode 100644 index 0000000000..54cebe988d --- /dev/null +++ b/releasenotes/notes/fix-pipeline-with-join-node-5f23a426cd4d88d9.yaml @@ -0,0 +1,69 @@ +--- +fixes: + - | + When using a `Pipeline` with a `JoinNode` (e.g. `JoinDocuments`) all information from the previous nodes was lost + other than a few select fields (e.g. `documents`). This was due to the `JoinNode` not properly passing on + the information from the previous nodes. This has been fixed and now all information from the previous nodes is + passed on to the next node in the pipeline. + + For example, this is a pipeline that rewrites the `query` during pipeline execution combined with a hybrid retrieval + setup that requires a `JoinDocuments` node. Specifically the first prompt node rewrites the `query` to fix all + spelling errors, and this new `query` is used for retrieval. And now the `JoinDocuments` node will now pass on the + rewritten `query` so it can be used by the `QAPromptNode` node whereas before it would pass on the original query. + ```python + from haystack import Pipeline + from haystack.nodes import BM25Retriever, EmbeddingRetriever, PromptNode, Shaper, JoinDocuments, PromptTemplate + from haystack.document_stores import InMemoryDocumentStore + + document_store = InMemoryDocumentStore(use_bm25=True) + dicts = [{"content": "The capital of Germany is Berlin."}, {"content": "The capital of France is Paris."}] + document_store.write_documents(dicts) + + query_prompt_node = PromptNode( + model_name_or_path="gpt-3.5-turbo", + api_key="", + default_prompt_template=PromptTemplate("You are a spell checker. Given a user query return the same query with all spelling errors fixed.\nUser Query: {query}\nSpell Checked Query:") + ) + shaper = Shaper( + func="join_strings", + inputs={"strings": "results"}, + outputs=["query"], + ) + qa_prompt_node = PromptNode( + model_name_or_path="gpt-3.5-turbo", + api_key="", + default_prompt_template=PromptTemplate("Answer the user query. Query: {query}") + ) + sparse_retriever = BM25Retriever( + document_store=document_store, + top_k=2 + ) + dense_retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="intfloat/e5-base-v2", + model_format="sentence_transformers", + top_k=2 + ) + document_store.update_embeddings(dense_retriever) + + pipeline = Pipeline() + pipeline.add_node(component=query_prompt_node, name="QueryPromptNode", inputs=["Query"]) + pipeline.add_node(component=shaper, name="ListToString", inputs=["QueryPromptNode"]) + pipeline.add_node(component=sparse_retriever, name="BM25", inputs=["ListToString"]) + pipeline.add_node(component=dense_retriever, name="Embedding", inputs=["ListToString"]) + pipeline.add_node( + component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["BM25", "Embedding"] + ) + pipeline.add_node(component=qa_prompt_node, name="QAPromptNode", inputs=["Join"]) + + out = pipeline.run(query="What is the captial of Grmny?", debug=True) + print(out["invocation_context"]) + # Before Fix + # {'query': 'What is the captial of Grmny?', <-- Original Query!! + # 'results': ['The capital of Germany is Berlin.'], + # 'prompts': ['Answer the user query. Query: What is the captial of Grmny?'], <-- Original Query!! + # After Fix + # {'query': 'What is the capital of Germany?', <-- Rewritten Query!! 
+ # 'results': ['The capital of Germany is Berlin.'], + # 'prompts': ['Answer the user query. Query: What is the capital of Germany?'], <-- Rewritten Query!! + ``` diff --git a/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml b/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml new file mode 100644 index 0000000000..bef5427a14 --- /dev/null +++ b/releasenotes/notes/fix-promptnode-empty-inputs-c050c2040d489f9e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + When passing empty inputs (such as `query=""`) to PromptNode, the node would raise an error. This has been fixed. diff --git a/releasenotes/notes/join-docs-weighting-rrf-c52ba00a25004fd4.yaml b/releasenotes/notes/join-docs-weighting-rrf-c52ba00a25004fd4.yaml new file mode 100644 index 0000000000..23a31d911e --- /dev/null +++ b/releasenotes/notes/join-docs-weighting-rrf-c52ba00a25004fd4.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Make `JoinDocuments` sensitive to `weights` parameter when + `join_mode` is reciprocal rank fusion. Add score normalization + for `JoinDocuments` when `join_mode` is reciprocal rank fusion. diff --git a/releasenotes/notes/mongodb-vector-search-2f40a9c67eed6e7c.yaml b/releasenotes/notes/mongodb-vector-search-2f40a9c67eed6e7c.yaml new file mode 100644 index 0000000000..959b5148ac --- /dev/null +++ b/releasenotes/notes/mongodb-vector-search-2f40a9c67eed6e7c.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixed a bug that caused the EmbeddingRetriever to return no documents when used with a MongoDBAtlasDocumentStore. MongoDBAtlasDocumentStore now accepts a vector_search_index parameter, which needs to be created before in the MongoDB Atlas Web UI following [their documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index). diff --git a/releasenotes/notes/optimize-pinecone-upsert.yaml b/releasenotes/notes/optimize-pinecone-upsert.yaml new file mode 100644 index 0000000000..9967db94ce --- /dev/null +++ b/releasenotes/notes/optimize-pinecone-upsert.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Optimize documents upsert in PineconeDocumentStore (write_documents) by enabling asynchronous requests. diff --git a/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml b/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml new file mode 100644 index 0000000000..bf0eb5b3c5 --- /dev/null +++ b/releasenotes/notes/override-api-base-67bc046a5cc5f46d.yaml @@ -0,0 +1,8 @@ +--- +enhancements: + - | + API_BASE can now be passed as an optional parameter in the getting_started sample. Only openai provider is supported in this set of changes. + PromptNode and PromptModel were enhanced to allow passing of this parameter. + This allows RAG against a local endpoint (e.g, http://localhost:1234/v1), so long as it is OpenAI compatible (such as LM Studio) + + Logging in the getting started sample was made more verbose, to make it easier for people to see what was happening under the covers. diff --git a/releasenotes/notes/pinecone-change-dummy-vector-b9fa90f2de6fb846.yaml b/releasenotes/notes/pinecone-change-dummy-vector-b9fa90f2de6fb846.yaml new file mode 100644 index 0000000000..bb0082d0c0 --- /dev/null +++ b/releasenotes/notes/pinecone-change-dummy-vector-b9fa90f2de6fb846.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + Change the dummy vector used internally in the Pinecone Document Store. + A recent change to the Pinecone API does not allow to use vectors filled with zeros + as was the previous dummy vector. 
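The `join-docs-weighting-rrf` note in this batch makes `JoinDocuments` honour the `weights` parameter when `join_mode` is reciprocal rank fusion and normalizes the fused scores. A small sketch under that assumption; the weight values and `top_k_join` are illustrative:
```python
from haystack.nodes import JoinDocuments

join_documents = JoinDocuments(
    join_mode="reciprocal_rank_fusion",
    weights=[0.7, 0.3],  # e.g. favour the first (sparse) input over the second (dense) one
    top_k_join=5,
)
# Drop-in replacement for the JoinDocuments(join_mode="concatenate") node in the
# pipeline example above, e.g. with inputs=["BM25", "Embedding"].
```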
diff --git a/releasenotes/notes/preprocessor-split-by-page-2de0b59175f4203e.yaml b/releasenotes/notes/preprocessor-split-by-page-2de0b59175f4203e.yaml new file mode 100644 index 0000000000..bd47ae96ea --- /dev/null +++ b/releasenotes/notes/preprocessor-split-by-page-2de0b59175f4203e.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added new option split_by="page" to the preprocessor so we can chunk documents by page break. diff --git a/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml b/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml new file mode 100644 index 0000000000..2f2a6bb497 --- /dev/null +++ b/releasenotes/notes/prompt-model-invocation-layer-e7a69a3ac3beb5a7.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Change `PromptModel` constructor parameter `invocation_layer_class` to accept a `str` too. + If a `str` is used the invocation layer class will be imported and used. + This should ease serialisation to YAML when using `invocation_layer_class` with `PromptModel`. diff --git a/releasenotes/notes/ranker-model-kwargs-0f60508b69d7d46e.yaml b/releasenotes/notes/ranker-model-kwargs-0f60508b69d7d46e.yaml new file mode 100644 index 0000000000..5cc82763fa --- /dev/null +++ b/releasenotes/notes/ranker-model-kwargs-0f60508b69d7d46e.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Add model_kwargs argument to SentenceTransformersRanker to be able to pass through HF transformers loading options + diff --git a/releasenotes/notes/remove-fetch_from_http-787be70bee86ebe4.yaml b/releasenotes/notes/remove-fetch_from_http-787be70bee86ebe4.yaml new file mode 100644 index 0000000000..4155dcce9c --- /dev/null +++ b/releasenotes/notes/remove-fetch_from_http-787be70bee86ebe4.yaml @@ -0,0 +1,8 @@ +--- +prelude: > + The utility functions `fetch_archive_from_http`, `build_pipeline` and + `add_example_data` were removed from Haystack. +upgrade: + - | + We recommend replacing calls to the `fetch_archive_from_http` function with + other tools available in Python or in the operating system of use. diff --git a/releasenotes/notes/remove-pymupdf-15fc66b581538adb.yaml b/releasenotes/notes/remove-pymupdf-15fc66b581538adb.yaml new file mode 100644 index 0000000000..1ee09c37d1 --- /dev/null +++ b/releasenotes/notes/remove-pymupdf-15fc66b581538adb.yaml @@ -0,0 +1,7 @@ +--- +prelude: > + This release changes the `PDFToTextConverter` so that it doesn't support PyMuPDF anymore. + The converter will always assume `xpdf` is used by default. +upgrade: + - | + To keep using PyMuPDF you must create a custom node, you can use the previous Haystack version for inspiration. diff --git a/releasenotes/notes/remove_answer_generator-e7100f82c1859fcb.yaml b/releasenotes/notes/remove_answer_generator-e7100f82c1859fcb.yaml new file mode 100644 index 0000000000..d4180aef42 --- /dev/null +++ b/releasenotes/notes/remove_answer_generator-e7100f82c1859fcb.yaml @@ -0,0 +1,5 @@ +--- +upgrade: + - | + Remove deprecated `OpenAIAnswerGenerator`, `BaseGenerator`, `GenerativeQAPipeline` and related tests. + GenerativeQA Pipelines should use PromptNode instead. 
See https://haystack.deepset.ai/tutorials/22_pipeline_with_promptnode diff --git a/releasenotes/notes/review-openai-context-windows-b48c2357da8ccac7.yaml b/releasenotes/notes/review-openai-context-windows-b48c2357da8ccac7.yaml new file mode 100644 index 0000000000..b011623836 --- /dev/null +++ b/releasenotes/notes/review-openai-context-windows-b48c2357da8ccac7.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Review and update context windows for OpenAI GPT models. \ No newline at end of file diff --git a/releasenotes/notes/route-documents-metadata-values-types-7b6bdbc916d2624b.yaml b/releasenotes/notes/route-documents-metadata-values-types-7b6bdbc916d2624b.yaml new file mode 100644 index 0000000000..f5eb9b1551 --- /dev/null +++ b/releasenotes/notes/route-documents-metadata-values-types-7b6bdbc916d2624b.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + The types of meta data values accepted by RouteDocuments was unnecessarily restricted to string types. + This causes validation errors (for example when loading from a yaml file) if a user tries to use a boolean type for example. + We add boolean and int types as valid types for metadata_values. diff --git a/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml b/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml new file mode 100644 index 0000000000..921e88dbb5 --- /dev/null +++ b/releasenotes/notes/safe-fetch-4ba829def3241eec.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add previously removed `fetch_archive_from_http` util function to fetch zip and gzip archives from url diff --git a/releasenotes/notes/specify-sentence-transformer-model-version-55b633e8b294189c.yaml b/releasenotes/notes/specify-sentence-transformer-model-version-55b633e8b294189c.yaml new file mode 100644 index 0000000000..2c43594726 --- /dev/null +++ b/releasenotes/notes/specify-sentence-transformer-model-version-55b633e8b294189c.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix bug causing latest version of sentence transformer model always being downloaded, even if specific version + is given. diff --git a/releasenotes/notes/summarizer-kwargs-speed-148714d268773f60.yaml b/releasenotes/notes/summarizer-kwargs-speed-148714d268773f60.yaml new file mode 100644 index 0000000000..d734894796 --- /dev/null +++ b/releasenotes/notes/summarizer-kwargs-speed-148714d268773f60.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Use batching in the predict method since multiple documents are usually passed at inference time. + Allow the model to be loaded in torch.float16 by adding pipeline_kwargs to the init method diff --git a/releasenotes/notes/support-gpt-3.5-turbo-1106-6de58fee16cc0c0a.yaml b/releasenotes/notes/support-gpt-3.5-turbo-1106-6de58fee16cc0c0a.yaml new file mode 100644 index 0000000000..9275080527 --- /dev/null +++ b/releasenotes/notes/support-gpt-3.5-turbo-1106-6de58fee16cc0c0a.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Correctly calculate the max token limit for gpt-3.5-turbo-1106 diff --git a/releasenotes/notes/support-hf-inference-gated-repos-c04be10438e08501.yaml b/releasenotes/notes/support-hf-inference-gated-repos-c04be10438e08501.yaml new file mode 100644 index 0000000000..87fb843b5e --- /dev/null +++ b/releasenotes/notes/support-hf-inference-gated-repos-c04be10438e08501.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Support gated repos for Huggingface inference. 
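The `preprocessor-split-by-page` note above adds a `split_by="page"` option. A minimal sketch, assuming pages are separated by form-feed characters as produced by the converters; the other parameter values are illustrative assumptions:
```python
from haystack.nodes import PreProcessor
from haystack.schema import Document

preprocessor = PreProcessor(
    split_by="page",                        # new option: chunk documents on page breaks
    split_length=1,                         # one page per chunk (assumed value)
    split_overlap=0,
    split_respect_sentence_boundary=False,  # sentence-boundary splitting only applies to split_by="word"
)

doc = Document(content="Page one text.\fPage two text.", meta={"name": "two_pages"})
chunks = preprocessor.process([doc])
print(len(chunks))  # expected to be 2, one chunk per page
```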
diff --git a/releasenotes/notes/update-transformers-e83863bbe69ddceb.yaml b/releasenotes/notes/update-transformers-e83863bbe69ddceb.yaml new file mode 100644 index 0000000000..56bae81b1b --- /dev/null +++ b/releasenotes/notes/update-transformers-e83863bbe69ddceb.yaml @@ -0,0 +1,8 @@ +--- +upgrade: + - | + Upgrade the transformers dependency requirement to "transformers>=4.46,<5.0". + +enhancements: + - | + Updated `tokenizer.json` URL for Anthropic models as the old URL was no longer available. diff --git a/releasenotes/notes/upgrade-ntlk-1e94de2d6f5dd3b6.yaml b/releasenotes/notes/upgrade-ntlk-1e94de2d6f5dd3b6.yaml new file mode 100644 index 0000000000..7aa7817cef --- /dev/null +++ b/releasenotes/notes/upgrade-ntlk-1e94de2d6f5dd3b6.yaml @@ -0,0 +1,9 @@ +fixes: + - | + Upgrades nltk to 3.9.1 as prior versions are affected by https://nvd.nist.gov/vuln/detail/CVE-2024-39705. +upgrade: + - | + Upgrades nltk to 3.9.1 as prior versions are affected by https://nvd.nist.gov/vuln/detail/CVE-2024-39705. Due to these security vulnerabilities, it is not possible to use custom NLTK tokenizer models with the new version (for example in PreProcessor). Users can still use built-in nltk tokenizers by specifying the language parameter in the PreProcessor. See PreProcessor documentation for more details. +enhancements: + - | + Pins sentence-transformers<=3.0.0,>=2.3.1 and python-pptx<=1.0 to avoid some minor typing incompatibilities with the newer version of the respective libraries. diff --git a/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml b/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml new file mode 100644 index 0000000000..f8b802b379 --- /dev/null +++ b/releasenotes/notes/verify-embed-dim-docustore-retriever-9ac88d8f0adc8a32.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add a check to verify that the embedding dimensions set in the FAISS Document Store and retriever are equal before running embedding calculations. diff --git a/releasenotes/notes/weaviate-handle-empty-list-8d3432080f8bfefd.yaml b/releasenotes/notes/weaviate-handle-empty-list-8d3432080f8bfefd.yaml new file mode 100644 index 0000000000..df6129bf80 --- /dev/null +++ b/releasenotes/notes/weaviate-handle-empty-list-8d3432080f8bfefd.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + Fixed a bug that made it impossible to write Documents + to Weaviate when some of the fields were empty lists + (e.g. `split_overlap` for preprocessed documents).
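The `verify-embed-dim-docustore-retriever` note above adds a dimension check between the FAISS Document Store and the retriever. A hedged sketch of the setup it guards; the model choice and the exact point where the check fires are assumptions based on the note's wording:
```python
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever

document_store = FAISSDocumentStore(embedding_dim=384, sql_url="sqlite:///faiss_document_store.db")
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # a 384-dimensional model (assumed choice)
)
# With matching dimensions this runs as before; a mismatch (e.g. a 768-dim model
# against a 384-dim store) is now reported before any embeddings are computed.
document_store.update_embeddings(retriever)
```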
diff --git a/rest_api/pyproject.toml b/rest_api/pyproject.toml index e5eda60e7d..2ee766c221 100644 --- a/rest_api/pyproject.toml +++ b/rest_api/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "farm-haystack", - "fastapi", + "fastapi~=0.108.0", "uvicorn<1", "python-multipart<1", # optional FastAPI dependency for form data "pynvml", diff --git a/test/conftest.py b/test/conftest.py index 9e6fdb7426..69ba93ece4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -28,7 +28,6 @@ from haystack.nodes import ( BaseReader, BaseRetriever, - BaseGenerator, BaseSummarizer, BaseTranslator, DenseRetriever, @@ -321,11 +320,6 @@ def embed_documents(self, documents: List[Document]): return np.full((len(documents), 768), 0.5) -class MockSeq2SegGenerator(BaseGenerator): - def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict: - pass - - class MockSummarizer(BaseSummarizer): def predict_batch( self, documents: Union[List[Document], List[List[Document]]], batch_size: Optional[int] = None @@ -441,6 +435,7 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "date_field": "2019-10-01", "numeric_field": 5.0, "list_field": ["item0.1", "item0.2"], + "page_number": 1, }, # "dict" format { @@ -451,6 +446,7 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "date_field": "2020-03-01", "numeric_field": 5.5, "list_field": ["item1.1", "item1.2"], + "page_number": 2, }, }, # Document object @@ -462,6 +458,7 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "date_field": "2018-10-01", "numeric_field": 4.5, "list_field": ["item2.1", "item2.2"], + "page_number": 3, }, ), Document( @@ -472,6 +469,7 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "date_field": "2021-02-01", "numeric_field": 3.0, "list_field": ["item3.1", "item3.2"], + "page_number": 4, }, ), Document( @@ -482,6 +480,7 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "date_field": "2019-01-01", "numeric_field": 0.0, "list_field": ["item4.1", "item4.2"], + "page_number": 5, }, ), ] @@ -896,11 +895,6 @@ def sample_txt_file_paths_list(samples_path): return list((samples_path / "docs").glob("*.txt")) -@pytest.fixture -def preview_samples_path(): - return Path(__file__).parent / "preview" / "test_files" - - @pytest.fixture(autouse=True) def request_blocker(request: pytest.FixtureRequest, monkeypatch): """ diff --git a/test/document_stores/test_faiss.py b/test/document_stores/test_faiss.py index b2c0171ae0..0ed24bc487 100644 --- a/test/document_stores/test_faiss.py +++ b/test/document_stores/test_faiss.py @@ -42,6 +42,21 @@ def test_index_mutual_exclusive_args(self, tmp_path): isolation_level="AUTOCOMMIT", ) + @pytest.mark.unit + def test_validate_embedding_dimension_unequal_embedding_dim(self, ds, documents): + retriever = MockDenseRetriever(document_store=ds, embedding_dim=384) + ds.write_documents(documents) + assert ds.get_document_count() == len(documents) + with pytest.raises(RuntimeError): + ds._validate_embedding_dimension(retriever) + + @pytest.mark.unit + def test_validate_embedding_dimension_equal_embedding_dim(self, ds, documents): + retriever = MockDenseRetriever(document_store=ds, embedding_dim=768) + ds.write_documents(documents) + assert ds.get_document_count() == len(documents) + ds._validate_embedding_dimension(retriever) + @pytest.mark.integration def test_delete_index(self, ds, documents): """Contrary to other Document Stores, FAISSDocumentStore doesn't raise if the index is 
empty""" diff --git a/test/document_stores/test_mongodb_atlas.py b/test/document_stores/test_mongodb_atlas.py new file mode 100644 index 0000000000..f6c4a20dfa --- /dev/null +++ b/test/document_stores/test_mongodb_atlas.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock, patch + +import pymongo +import pytest +from numpy import float32, random + +from haystack.document_stores import mongodb_atlas +from haystack.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +class TestMongoDBDocumentStore: + @pytest.fixture + def mocked_ds(self): + class DSMock(MongoDBAtlasDocumentStore): + # We mock a subclass to avoid messing up the actual class object + pass + + mongodb_atlas._validate_mongo_connection_string = MagicMock() + mongodb_atlas._validate_database_name = MagicMock() + mongodb_atlas._validate_collection_name = MagicMock() + mongodb_atlas._get_collection = MagicMock() + pymongo.MongoClient = MagicMock() + + mocked_ds = DSMock( + mongo_connection_string="mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}", + database_name="test_db", + collection_name="test_collection", + embedding_dim=1536, + ) + + return mocked_ds + + @pytest.mark.unit + def test_error_is_raised_if_vector_index_name_is_not_set_for_vector_search(self, mocked_ds): + with pytest.raises(ValueError): + mocked_ds.vector_search_index = None + mocked_ds.query_by_embedding(query_emb=random.rand(768)) + + @pytest.mark.unit + def test_vector_index_name_is_set_for_vector_search(self, mocked_ds): + mocked_ds.vector_search_index = "vector_search_index" + with patch.object(mocked_ds, "_get_collection", return_value=MagicMock()) as mock_get_collection: + query_emb = random.rand(768) + mocked_ds.query_by_embedding(query_emb=query_emb) + expected_emb_in_call = query_emb.astype(float32) + mocked_ds.normalize_embedding(expected_emb_in_call) + # check that the correct arguments are passed to collection.aggregate() + mock_get_collection().aggregate.assert_called() + mock_get_collection().aggregate.assert_called_once_with( + [ + { + "$vectorSearch": { + "index": "vector_search_index", + "queryVector": expected_emb_in_call.tolist(), + "path": "embedding", + "numCandidates": 100, + "limit": 10, + } + }, + {"$match": {}}, + {"$project": {"embedding": False}}, + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + ] + ) diff --git a/test/document_stores/test_pinecone.py b/test/document_stores/test_pinecone.py index 6bacb9e48b..7d92aef3e4 100644 --- a/test/document_stores/test_pinecone.py +++ b/test/document_stores/test_pinecone.py @@ -38,11 +38,16 @@ def ds(self, monkeypatch, request) -> PineconeDocumentStore: monkeypatch.setattr(f"pinecone.{fname}", function, raising=False) for cname, class_ in getmembers(pinecone_mock, isclass): monkeypatch.setattr(f"pinecone.{cname}", class_, raising=False) + params = getattr(request, "param", {}) + pods = params.get("pods", None) + pod_type = params.get("pod_type", None) return PineconeDocumentStore( api_key=os.environ.get("PINECONE_API_KEY") or "fake-pinecone-test-key", embedding_dim=768, embedding_field="embedding", + pods=pods, + pod_type=pod_type, index="haystack_tests", similarity="cosine", recreate_index=True, diff --git a/test/document_stores/test_search_engine.py b/test/document_stores/test_search_engine.py index fc819fa7ed..30f4225b07 100644 --- a/test/document_stores/test_search_engine.py +++ b/test/document_stores/test_search_engine.py @@ -4,7 +4,7 @@ import numpy as np import pytest from 
haystack.document_stores.search_engine import SearchEngineDocumentStore -from haystack.schema import FilterType +from haystack.schema import Document, FilterType @pytest.mark.unit @@ -60,6 +60,14 @@ def test_get_meta_values_by_key(self, ds, documents): result = ds.get_metadata_values_by_key(key="year", query="Bar") assert result == [{"count": 3, "value": "2021"}] + @pytest.mark.integration + def test_get_meta_values_by_key_with_batch_size(self, ds): + docs = [Document(f"content_{i}", meta={"name": f"name_{i}"}) for i in range(10_000)] + ds.write_documents(docs) + + result = ds.get_metadata_values_by_key(key="name", batch_size=1_000) + assert result == sorted([{"count": 1, "value": f"name_{i}"} for i in range(10_000)], key=lambda x: x["value"]) + @pytest.mark.unit def test_query_return_embedding_true(self, mocked_document_store): mocked_document_store.return_embedding = True diff --git a/test/document_stores/test_weaviate.py b/test/document_stores/test_weaviate.py index 050a042b87..4888bd306b 100644 --- a/test/document_stores/test_weaviate.py +++ b/test/document_stores/test_weaviate.py @@ -9,6 +9,7 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore from haystack.schema import Document from haystack.testing import DocumentStoreBaseTestAbstract +from haystack.nodes.preprocessor import PreProcessor embedding_dim = 768 @@ -267,6 +268,33 @@ def test_get_embedding_count(self, ds, documents): ds.write_documents(documents) assert ds.get_embedding_count() == 9 + @pytest.mark.integration + def test_write_preprocessed_docs(self, ds, documents): + """ + Test that preprocessed documents can be correctly written to Weaviate + even if the meta field `_split_overlap` is an empty list for some documents. + """ + preprocessor = PreProcessor( + clean_empty_lines=True, + clean_whitespace=True, + clean_header_footer=True, + split_by="word", + split_length=5, + split_overlap=2, + split_respect_sentence_boundary=False, + ) + + longer_doc = Document(content="This is a longer document that will be split into multiple parts.") + documents.append(longer_doc) + + preprocessed_docs = preprocessor.process(documents) + + ds.write_documents(preprocessed_docs) + + docs_from_weaviate = ds.get_all_documents() + for doc in docs_from_weaviate: + assert "_split_overlap" in doc.meta + @pytest.mark.unit def test__get_auth_secret(self): # Test with username and password diff --git a/test/mocks/pinecone.py b/test/mocks/pinecone.py index e25255f2c2..0b1857eaca 100644 --- a/test/mocks/pinecone.py +++ b/test/mocks/pinecone.py @@ -45,6 +45,8 @@ def __init__( api_key: Optional[str] = None, environment: Optional[str] = None, dimension: Optional[int] = None, + pods: Optional[int] = None, + pod_type: Optional[str] = None, metric: Optional[str] = None, replicas: Optional[int] = None, shards: Optional[int] = None, @@ -55,6 +57,8 @@ def __init__( self.environment = environment self.dimension = dimension self.metric = metric + self.pods = pods + self.pod_type = pod_type self.replicas = replicas self.shards = shards self.metadata_config = metadata_config @@ -63,11 +67,12 @@ def __init__( # Mock the Pinecone Index class class Index: - def __init__(self, index: str): + def __init__(self, index: str, pool_threads: int = 1): self.index = index + self.pool_threads = pool_threads self.index_config = CONFIG["indexes"][index] - def upsert(self, vectors: List[tuple], namespace: str = ""): + def upsert(self, vectors: List[tuple], namespace: str = "", async_req: bool = False): if namespace not in self.index_config.namespaces: 
self.index_config.namespaces[namespace] = {} upsert_count = 0 @@ -338,6 +343,8 @@ def create_index( dimension: int, metric: str = "cosine", replicas: int = 1, + pods: int = 1, + pod_type: str = "p1.x1", shards: int = 1, metadata_config: Optional[dict] = None, ): @@ -348,6 +355,8 @@ def create_index( dimension=dimension, metric=metric, replicas=replicas, + pods=pods, + pod_type=pod_type, shards=shards, metadata_config=metadata_config, ) diff --git a/test/nodes/test_connector.py b/test/nodes/test_connector.py index b346600c0a..68f31befdb 100644 --- a/test/nodes/test_connector.py +++ b/test/nodes/test_connector.py @@ -1,14 +1,13 @@ -from typing import List - -import json -from pathlib import Path -import re import hashlib +import json import os -from unittest.mock import patch +import re +from pathlib import Path +from typing import List +from unittest.mock import Mock, patch import pytest - +from selenium.webdriver import Chrome from selenium.webdriver.common.by import By from haystack.nodes.connector.crawler import Crawler @@ -54,8 +53,9 @@ def content_in_results(crawler: Crawler, url: str, results: List[Path], expected @pytest.mark.unit -@patch("haystack.nodes.connector.crawler.webdriver") -def test_crawler_url_none_exception(webdriver): +@patch("haystack.nodes.connector.crawler.Service") +@patch("haystack.nodes.connector.crawler.selenium_webdriver") +def test_crawler_url_none_exception(service, webdriver): crawler = Crawler() with pytest.raises(ValueError): crawler.crawl() @@ -258,3 +258,23 @@ def test_crawler_depth_2_multiple_urls(test_url, tmp_path): assert content_in_results(crawler, test_url + "/page1_subpage1.html", paths) assert content_in_results(crawler, test_url + "/page1_subpage2.html", paths) assert content_in_results(crawler, test_url + "/page2_subpage1.html", paths) + + +@pytest.mark.unit +def test_crawler_custom_webdriver(): + webdriver = Mock(Chrome) + crawler = Crawler(webdriver=webdriver) + + assert webdriver is crawler.driver + + +@pytest.mark.integration +def test_crawler_pdf_download_location(samples_path, tmp_path): + crawler = Crawler(output_dir=tmp_path) + + file_name = "sample_pdf_1.pdf" + pdf_uri = (samples_path / "pdf" / file_name).absolute().as_uri() + documents = crawler.crawl(urls=[pdf_uri]) + assert len(documents) == 1 + assert (tmp_path / file_name).exists() + assert len(os.listdir(tmp_path)) == 2 diff --git a/test/nodes/test_file_converter.py b/test/nodes/test_file_converter.py index 9daee4b587..dd6ea0d241 100644 --- a/test/nodes/test_file_converter.py +++ b/test/nodes/test_file_converter.py @@ -47,125 +47,6 @@ def test_convert(Converter, samples_path): assert "Adobe Systems made the PDF specification available free of charge in 1993." 
in page_standard_whitespace -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_command_whitespaces(Converter, samples_path): - converter = Converter() - - document = converter.run(file_paths=samples_path / "pdf" / "sample pdf file with spaces on file name.pdf")[0][ - "documents" - ][0] - assert "ɪ" in document.content - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_encoding(Converter, samples_path): - converter = Converter() - - document = converter.run(file_paths=samples_path / "pdf" / "sample_pdf_5.pdf")[0]["documents"][0] - assert "Ж" in document.content - - document = converter.run(file_paths=samples_path / "pdf" / "sample_pdf_2.pdf")[0]["documents"][0] - assert "ɪ" in document.content - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_sort_by_position(Converter, samples_path): - converter = Converter(sort_by_position=True) - - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_3.pdf")[0] - assert str(document.content).startswith("This is the second test sentence.") - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_ligatures(Converter, samples_path): - converter = Converter() - - document = converter.run(file_paths=samples_path / "pdf" / "sample_pdf_2.pdf")[0]["documents"][0] - assert "ff" not in document.content - assert "ɪ" in document.content - - document = converter.run(file_paths=samples_path / "pdf" / "sample_pdf_2.pdf", known_ligatures={})[0]["documents"][ - 0 - ] - assert "ff" in document.content - assert "ɪ" in document.content - - document = converter.run(file_paths=samples_path / "pdf" / "sample_pdf_2.pdf", known_ligatures={"ɪ": "i"})[0][ - "documents" - ][0] - assert "ff" in document.content - assert "ɪ" not in document.content - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_page_range(Converter, samples_path): - converter = Converter() - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_1.pdf", start_page=2)[0] - pages = document.content.split("\f") - - assert ( - len(pages) == 4 - ) # the sample PDF file has four pages, we skipped first (but we wanna correct number of pages) - assert pages[0] == "" # the page 1 was skipped. - assert pages[1] != "" # the page 2 is not empty. - assert pages[2] == "" # the page 3 is empty. - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_page_range_numbers(Converter, samples_path): - converter = Converter() - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_1.pdf", start_page=2)[0] - - preprocessor = PreProcessor( - split_by="word", split_length=5, split_overlap=0, split_respect_sentence_boundary=False, add_page_number=True - ) - documents = preprocessor.process([document]) - - assert documents[1].meta["page"] == 4 - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_parallel(Converter, samples_path): - converter = Converter(multiprocessing=True) - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_6.pdf")[0] - - pages = document.content.split("\f") - - assert pages[0] == "This is the page 1 of the document." - assert pages[-1] == "This is the page 50 of the document." 
- - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_parallel_page_range(Converter, samples_path): - converter = Converter(multiprocessing=True) - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_6.pdf", start_page=2)[0] - - pages = document.content.split("\f") - - assert pages[0] == "" - assert len(pages) == 50 - - -@pytest.mark.unit -@pytest.mark.parametrize("Converter", [PDFToTextConverter]) -def test_pdf_parallel_sort_by_position(Converter, samples_path): - converter = Converter(multiprocessing=True, sort_by_position=True) - document = converter.convert(file_path=samples_path / "pdf" / "sample_pdf_6.pdf")[0] - - pages = document.content.split("\f") - - assert pages[0] == "This is the page 1 of the document." - assert pages[-1] == "This is the page 50 of the document." - - @pytest.mark.integration @pytest.mark.parametrize("Converter", [PDFToTextConverter]) def test_pdf_parallel_ocr(Converter, samples_path): @@ -422,6 +303,34 @@ def test_csv_to_document_with_wrong_qa_headers(tmp_path): node.run(file_paths=csv_path) +@pytest.mark.unit +def test_csv_to_document_with_wrong_qa_headers_raise_on_failure_true(tmp_path): + node = CsvTextConverter() + csv_path = tmp_path / "csv_qa_with_wrong_headers.csv" + rows = [ + ["wrong", "headers"], + ["What is Haystack ?", "Haystack is an NLP Framework to use transformers in your Applications."], + ] + write_as_csv(rows, csv_path) + + with pytest.raises(ValueError): + node.run(file_paths=csv_path, raise_on_failure=True) + + +@pytest.mark.unit +def test_csv_to_document_with_wrong_qa_headers_raise_on_failure_false(tmp_path): + node = CsvTextConverter() + csv_path = tmp_path / "csv_qa_with_wrong_headers.csv" + rows = [ + ["wrong", "headers"], + ["What is Haystack ?", "Haystack is an NLP Framework to use transformers in your Applications."], + ] + write_as_csv(rows, csv_path) + + result, _ = node.run(file_paths=csv_path, raise_on_failure=False) + assert len(result["documents"]) == 0 + + @pytest.mark.unit def test_csv_to_document_with_one_wrong_qa_headers(tmp_path): node = CsvTextConverter() diff --git a/test/nodes/test_filetype_classifier.py b/test/nodes/test_filetype_classifier.py index 55fb34455f..5b0f12c8da 100644 --- a/test/nodes/test_filetype_classifier.py +++ b/test/nodes/test_filetype_classifier.py @@ -166,3 +166,19 @@ def test_filetype_classifier_batched_same_media_extensions(tmp_path): output, edge = node.run_batch(test_files) assert edge == "output_1" assert output == {"file_paths": test_files} + + +@pytest.mark.unit +@pytest.mark.parametrize("file_type", ["csv", "json", "xml", "pptx", "xlsx"]) +def test_filetype_classifier_raise_on_error_disabled_unsupported_file_types(tmp_path, caplog, file_type): + node = FileTypeClassifier(raise_on_error=False) + test_file = tmp_path / f"test.{file_type}" + caplog.clear() + with caplog.at_level(logging.WARNING): + output, edge = node.run(test_file) + assert edge == "output_dead_end" + assert output == {"file_paths": [test_file]} + assert ( + f"Unsupported files of type '{file_type}' ({test_file!s}) found. 
Unsupported file types will be ignored" + in caplog.text + ) diff --git a/test/nodes/test_generator.py b/test/nodes/test_generator.py deleted file mode 100644 index f69663aeba..0000000000 --- a/test/nodes/test_generator.py +++ /dev/null @@ -1,208 +0,0 @@ -from unittest.mock import patch, create_autospec - -import pytest -from haystack import Pipeline -from haystack.schema import Document, Answer -from haystack.nodes.answer_generator import OpenAIAnswerGenerator -from haystack.nodes import PromptTemplate - -from ..conftest import fail_at_version - -import logging - - -@pytest.mark.unit -@fail_at_version(1, 23) -@patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer") -def test_openaianswergenerator_deprecation(mock_load_tokenizer): - with pytest.warns(DeprecationWarning): - OpenAIAnswerGenerator(api_key="fake_api_key") - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.openai.openai_request") -def test_no_openai_organization(mock_request): - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - generator = OpenAIAnswerGenerator(api_key="fake_api_key") - assert generator.openai_organization is None - - generator.predict(query="test query", documents=[Document(content="test document")]) - assert "OpenAI-Organization" not in mock_request.call_args.kwargs["headers"] - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.openai.openai_request") -def test_openai_organization(mock_request): - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - generator = OpenAIAnswerGenerator(api_key="fake_api_key", openai_organization="fake_organization") - assert generator.openai_organization == "fake_organization" - - generator.predict(query="test query", documents=[Document(content="test document")]) - assert mock_request.call_args.kwargs["headers"]["OpenAI-Organization"] == "fake_organization" - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.openai.openai_request") -def test_openai_answer_generator_default_api_base(mock_request): - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - generator = OpenAIAnswerGenerator(api_key="fake_api_key") - assert generator.api_base == "https://api.openai.com/v1" - generator.predict(query="test query", documents=[Document(content="test document")]) - assert mock_request.call_args.kwargs["url"] == "https://api.openai.com/v1/completions" - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.openai.openai_request") -def test_openai_answer_generator_custom_api_base(mock_request): - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - generator = OpenAIAnswerGenerator(api_key="fake_api_key", api_base="https://fake_api_base.com") - assert generator.api_base == "https://fake_api_base.com" - generator.predict(query="test query", documents=[Document(content="test document")]) - assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/completions" - - -@pytest.mark.integration -@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True) -def test_openai_answer_generator(haystack_openai_config, docs): - if not haystack_openai_config: - pytest.skip("No API key found, skipping test") - - openai_generator = OpenAIAnswerGenerator( - api_key=haystack_openai_config["api_key"], - azure_base_url=haystack_openai_config.get("azure_base_url", None), - 
azure_deployment_name=haystack_openai_config.get("azure_deployment_name", None), - model="text-babbage-001", - top_k=1, - ) - prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1) - assert len(prediction["answers"]) == 1 - assert "Carla" in prediction["answers"][0].answer - - -@pytest.mark.integration -@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True) -def test_openai_answer_generator_custom_template(haystack_openai_config, docs): - if not haystack_openai_config: - pytest.skip("No API key found, skipping test") - - lfqa_prompt = PromptTemplate( - """Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and - the given question.\n===\\Paragraphs: {context}\n===\n{query}""" - ) - node = OpenAIAnswerGenerator( - api_key=haystack_openai_config["api_key"], - azure_base_url=haystack_openai_config.get("azure_base_url", None), - azure_deployment_name=haystack_openai_config.get("azure_deployment_name", None), - model="text-babbage-001", - top_k=1, - prompt_template=lfqa_prompt, - ) - prediction = node.predict(query="Who lives in Berlin?", documents=docs, top_k=1) - assert len(prediction["answers"]) == 1 - - -@pytest.mark.integration -@pytest.mark.parametrize("haystack_openai_config", ["openai", "azure"], indirect=True) -def test_openai_answer_generator_max_token(haystack_openai_config, docs, caplog): - if not haystack_openai_config: - pytest.skip("No API key found, skipping test") - - openai_generator = OpenAIAnswerGenerator( - api_key=haystack_openai_config["api_key"], - azure_base_url=haystack_openai_config.get("azure_base_url", None), - azure_deployment_name=haystack_openai_config.get("azure_deployment_name", None), - model="text-babbage-001", - top_k=1, - ) - openai_generator.MAX_TOKENS_LIMIT = 116 - with caplog.at_level(logging.INFO): - prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1) - assert "Skipping all of the provided Documents" in caplog.text - assert len(prediction["answers"]) == 1 - # Can't easily check content of answer since it is generative and can change between runs - - -# mock tokenizer that splits the string -class MockTokenizer: - def encode(self, *args, **kwargs): - return str.split(*args, **kwargs) - - def tokenize(self, *args, **kwargs): - return str.split(*args, **kwargs) - - -@pytest.mark.unit -def test_build_prompt_within_max_length(): - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer") as mock_load_tokenizer: - mock_load_tokenizer.return_value = MockTokenizer() - - generator = OpenAIAnswerGenerator(api_key="fake_key", max_tokens=50) - generator.MAX_TOKENS_LIMIT = 92 - query = "query" - documents = [Document("most relevant document"), Document("less relevant document")] - prompt_str, prompt_docs = generator._build_prompt_within_max_length(query=query, documents=documents) - - assert len(prompt_docs) == 1 - assert prompt_docs[0] == documents[0] - - -@pytest.mark.unit -def test_openai_answer_generator_pipeline_max_tokens(): - """ - tests that the max_tokens parameter is passed to the generator component in the pipeline - """ - question = "What is New York City like?" - mocked_response = "Forget NYC, I was generated by the mock method." 
- nyc_docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")] - pipeline = Pipeline() - - # mock load_openai_tokenizer to avoid accessing the internet to init tiktoken - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1) - - pipeline.add_node(component=openai_generator, name="generator", inputs=["Query"]) - openai_generator.run = create_autospec(openai_generator.run) - openai_generator.run.return_value = ({"answers": mocked_response}, "output_1") - - result = pipeline.run(query=question, documents=nyc_docs, params={"generator": {"max_tokens": 3}}) - assert result["answers"] == mocked_response - openai_generator.run.assert_called_with(query=question, documents=nyc_docs, max_tokens=3) - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.openai.OpenAIAnswerGenerator.predict") -def test_openai_answer_generator_run_with_labels_and_isolated_node_eval(patched_predict, eval_labels): - label = eval_labels[0] - query = label.query - document = label.labels[0].document - - patched_predict.return_value = { - "answers": [Answer(answer=label.labels[0].answer.answer, document_ids=[document.id])] - } - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1) - result, _ = openai_generator.run(query=query, documents=[document], labels=label, add_isolated_node_eval=True) - - assert "answers_isolated" in result - - -@pytest.mark.unit -@patch("haystack.nodes.answer_generator.base.BaseGenerator.predict_batch") -def test_openai_answer_generator_run_batch_with_labels_and_isolated_node_eval(patched_predict_batch, eval_labels): - queries = [label.query for label in eval_labels] - documents = [[label.labels[0].document] for label in eval_labels] - - patched_predict_batch.return_value = { - "queries": queries, - "answers": [ - [Answer(answer=label.labels[0].answer.answer, document_ids=[label.labels[0].document.id])] - for label in eval_labels - ], - } - with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"): - openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1) - result, _ = openai_generator.run_batch( - queries=queries, documents=documents, labels=eval_labels, add_isolated_node_eval=True - ) - - assert "answers_isolated" in result diff --git a/test/nodes/test_join_documents.py b/test/nodes/test_join_documents.py index 463aeaa577..5682920639 100644 --- a/test/nodes/test_join_documents.py +++ b/test/nodes/test_join_documents.py @@ -1,8 +1,9 @@ import pytest -from haystack import Document +from haystack import Document, Pipeline from haystack.nodes.other.join_docs import JoinDocuments +from copy import deepcopy @pytest.mark.unit @@ -113,3 +114,60 @@ def test_joindocuments_concatenate_duplicate_docs_null_score(): result, _ = join_docs.run(inputs) assert len(result["documents"]) == 3 assert result["documents"] == expected_outputs["documents"] + + +@pytest.mark.unit +def test_joindocuments_rrf_weights(): + """ + Test that the reciprocal rank fusion method correctly handles weights. 
+ """ + inputs_none = [ + { + "documents": [ + Document(content="text document 1", content_type="text", score=0.2), + Document(content="text document 2", content_type="text", score=0.3), + ] + }, + { + "documents": [ + Document(content="text document 3", content_type="text", score=0.7), + Document(content="text document 4", content_type="text", score=None), + ] + }, + ] + + inputs_even = deepcopy(inputs_none) + inputs_uneven = deepcopy(inputs_none) + + join_docs_none = JoinDocuments(join_mode="reciprocal_rank_fusion") + result_none, _ = join_docs_none.run(inputs_none) + join_docs_even = JoinDocuments(join_mode="reciprocal_rank_fusion", weights=[0.5, 0.5]) + result_even, _ = join_docs_even.run(inputs_even) + join_docs_uneven = JoinDocuments(join_mode="reciprocal_rank_fusion", weights=[0.7, 0.3]) + result_uneven, _ = join_docs_uneven.run(inputs_uneven) + + assert result_none["documents"] == result_even["documents"] + assert result_uneven["documents"] != result_none["documents"] + assert result_uneven["documents"][0].score > result_none["documents"][0].score + + +@pytest.mark.unit +def test_join_node_empty_documents(): + pipe = Pipeline() + join_node = JoinDocuments(join_mode="concatenate") + pipe.add_node(component=join_node, name="Join", inputs=["Query"]) + + # Test single document lists + output = pipe.run(query="test", documents=[]) + assert len(output["documents"]) == 0 + + +@pytest.mark.unit +def test_join_node_none_documents(): + pipe = Pipeline() + join_node = JoinDocuments(join_mode="concatenate") + pipe.add_node(component=join_node, name="Join", inputs=["Query"]) + + # Test single document lists + output = pipe.run(query="test", documents=None) + assert len(output["documents"]) == 0 diff --git a/test/nodes/test_link_content_fetcher.py b/test/nodes/test_link_content_fetcher.py index 54c0220e68..05682c763a 100644 --- a/test/nodes/test_link_content_fetcher.py +++ b/test/nodes/test_link_content_fetcher.py @@ -36,7 +36,6 @@ def test_init(): assert r.processor is None assert isinstance(r.handlers, dict) assert "text/html" in r.handlers - assert "application/pdf" in r.handlers @pytest.mark.unit @@ -49,7 +48,6 @@ def test_init_with_preprocessor(): assert r.processor == pre_processor_mock assert isinstance(r.handlers, dict) assert "text/html" in r.handlers - assert "application/pdf" in r.handlers @pytest.mark.unit @@ -65,7 +63,6 @@ def fake_but_valid_video_content_handler(response: Response) -> Optional[str]: assert isinstance(r.handlers, dict) assert "text/html" in r.handlers - assert "application/pdf" in r.handlers assert "video/mp4" in r.handlers diff --git a/test/nodes/test_preprocessor.py b/test/nodes/test_preprocessor.py index 604654d578..551e63f4e2 100644 --- a/test/nodes/test_preprocessor.py +++ b/test/nodes/test_preprocessor.py @@ -12,7 +12,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase from haystack import Document -from haystack.nodes.file_converter.pdf import PDFToTextConverter +from haystack.nodes.file_converter.pdf_xpdf import PDFToTextConverter from haystack.nodes.preprocessor.preprocessor import PreProcessor @@ -175,6 +175,7 @@ def test_preprocess_sentence_split_custom_models_non_default_language(split_leng @pytest.mark.unit @pytest.mark.parametrize("split_length_and_results", [(1, 8), (8, 1)]) +@pytest.mark.skip(reason="Skipped after upgrade to nltk 3.9, can't load this model pt anymore") def test_preprocess_sentence_split_custom_models(split_length_and_results, samples_path): split_length, expected_documents_count = split_length_and_results @@ -219,6 
+220,66 @@ def test_preprocess_word_split(): assert len(documents) == 15 +@pytest.mark.unit +def test_preprocess_page_split(): + doc = Document( + content="This is a document on page 1.\fThis is a document on page 2.\fThis is a document on page 3." + ) + output = PreProcessor( + split_by="page", split_length=1, split_respect_sentence_boundary=False, split_overlap=0, add_page_number=True + ).run([doc])[0]["documents"] + assert len(output) == 3 + assert output[0] == Document(content="This is a document on page 1.", meta={"_split_id": 0, "page": 1}) + assert output[1] == Document(content="This is a document on page 2.", meta={"_split_id": 1, "page": 2}) + assert output[2] == Document(content="This is a document on page 3.", meta={"_split_id": 2, "page": 3}) + + +@pytest.mark.unit +def test_preprocess_page_split_and_split_length(): + doc = Document( + content="This is a document on page 1.\fThis is a document on page 2.\fThis is a document on page 3." + ) + output = PreProcessor( + split_by="page", split_length=2, split_respect_sentence_boundary=False, split_overlap=0, add_page_number=True + ).run([doc])[0]["documents"] + assert len(output) == 2 + assert output[0] == Document( + content="This is a document on page 1.\fThis is a document on page 2.", meta={"_split_id": 0, "page": 1} + ) + assert output[1] == Document(content="This is a document on page 3.", meta={"_split_id": 1, "page": 3}) + + +@pytest.mark.unit +def test_preprocess_page_split_and_split_overlap(): + doc = Document( + content="This is a document on page 1.\fThis is a document on page 2.\fThis is a document on page 3." + ) + output = PreProcessor( + split_by="page", split_length=2, split_respect_sentence_boundary=False, split_overlap=1, add_page_number=True + ).run([doc])[0]["documents"] + assert len(output) == 2 + assert output[0].content == "This is a document on page 1.\fThis is a document on page 2." + assert output[0].meta["_split_id"] == 0 + assert output[0].meta["page"] == 1 + assert output[1].content == "This is a document on page 2.\fThis is a document on page 3." + assert output[1].meta["_split_id"] == 1 + assert output[1].meta["page"] == 2 + + +@pytest.mark.unit +def test_preprocess_page_split_with_empty_pages(): + doc = Document( + content="This is a document on page 1.\f\fThis is a document on page 3.\f\fThis is a document on page 5." 
+ ) + output = PreProcessor( + split_by="page", split_length=1, split_respect_sentence_boundary=False, split_overlap=0, add_page_number=True + ).run([doc])[0]["documents"] + assert len(output) == 3 + assert output[0] == Document(content="This is a document on page 1.", meta={"_split_id": 0, "page": 1}) + assert output[1] == Document(content="This is a document on page 3.", meta={"_split_id": 1, "page": 3}) + assert output[2] == Document(content="This is a document on page 5.", meta={"_split_id": 2, "page": 5}) + + @pytest.mark.unit def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer): raw_docs = [ @@ -240,7 +301,7 @@ def test_preprocess_tiktoken_token_split(mock_tiktoken_tokenizer): enc.encode(d.content, allowed_special="all", disallowed_special=()) for d in token_split_docs_not_respecting_sentences ] - assert all([len(d) <= split_length for d in split_documents_encoded]) + assert all(len(d) <= split_length for d in split_documents_encoded) token_split_docs_respecting_sentences = PreProcessor( split_by="token", split_length=split_length, @@ -269,7 +330,7 @@ def test_preprocess_huggingface_token_split(mock_huggingface_tokenizer): ).process(docs) assert len(token_split_docs_not_respecting_sentences) == 8 split_documents_retokenized = [tokenizer.tokenize(d.content) for d in token_split_docs_not_respecting_sentences] - assert all([len(d) <= split_length for d in split_documents_retokenized]) + assert all(len(d) <= split_length for d in split_documents_retokenized) token_split_docs_respecting_sentences = PreProcessor( split_by="token", split_length=split_length, diff --git a/test/nodes/test_ranker.py b/test/nodes/test_ranker.py index baedd00860..8d7cb83db4 100644 --- a/test/nodes/test_ranker.py +++ b/test/nodes/test_ranker.py @@ -234,6 +234,18 @@ def test_ranker(docs, mock_transformer_model, mock_transformer_tokenizer): assert results[0] == docs[4] +@pytest.mark.unit +def test_init_called_with(): + with patch("haystack.nodes.SentenceTransformersRanker.__init__") as mock_ranker_init: + mock_ranker_init.return_value = None + _ = SentenceTransformersRanker( + model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16} + ) + mock_ranker_init.assert_called_once_with( + model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16} + ) + + @pytest.mark.unit def test_ranker_run(docs, mock_transformer_model, mock_transformer_tokenizer): with patch("torch.nn.DataParallel"): diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py index c97bf418f5..435cd49b24 100644 --- a/test/nodes/test_reader.py +++ b/test/nodes/test_reader.py @@ -5,6 +5,7 @@ import pytest +import torch from huggingface_hub import snapshot_download from haystack.modeling.data_handler.inputs import QAInput, Question @@ -45,6 +46,71 @@ def test_reader_basic(reader): assert isinstance(reader, BaseReader) +@patch("haystack.nodes.reader.farm.QAInferencer") +def test_add_answer_page_number(mocked_qa_inferencer) -> None: + documents = [ + Document(content="This is a test.\fSentence on second page about nothing.", meta={"page_number": 1}), + Document(content="Second sentence on the second page.", meta={"page_number": 2}), + ] + reader = FARMReader(model_name_or_path="fake_model", use_gpu=False) + answer_with_meta = reader._add_answer_page_number( + documents=documents, + answer=Answer( + answer="nothing", + type="extractive", + score=0.2, + context=documents[0].content, + document_ids=[documents[0].id], + offsets_in_document=[Span(start=46, end=46 + 
len("nothing"))], + ), + ) + assert answer_with_meta.meta is not None + assert answer_with_meta.meta["answer_page_number"] == 2 + + +@patch("haystack.nodes.reader.farm.QAInferencer") +def test_add_answer_page_number_no_doc_page(mocked_qa_inferencer) -> None: + documents = [ + Document(content="This is a test.\fSentence on second page about nothing."), + Document(content="Second sentence on the second page."), + ] + reader = FARMReader(model_name_or_path="fake_model", use_gpu=False) + answer_with_meta = reader._add_answer_page_number( + documents=documents, + answer=Answer( + answer="nothing", + type="extractive", + score=0.2, + context=documents[0].content, + document_ids=[documents[0].id], + offsets_in_document=[Span(start=46, end=46 + len("nothing"))], + ), + ) + assert answer_with_meta.meta == {} + + +@patch("haystack.nodes.reader.farm.QAInferencer") +def test_add_answer_page_number_with_meta(mocked_qa_inferencer) -> None: + documents = [ + Document(content="This is a test.\fSentence on second page about nothing.", meta={"page_number": 1}), + Document(content="Second sentence on the second page."), + ] + reader = FARMReader(model_name_or_path="fake_model", use_gpu=False) + answer_with_meta = reader._add_answer_page_number( + documents=documents, + answer=Answer( + answer="nothing", + type="extractive", + score=0.2, + context=documents[0].content, + document_ids=[documents[0].id], + offsets_in_document=[Span(start=46, end=46 + len("nothing"))], + meta={"test": 1}, + ), + ) + assert answer_with_meta.meta == {"test": 1, "answer_page_number": 2} + + def test_output(reader, docs): prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5) assert prediction is not None @@ -58,6 +124,8 @@ def test_output(reader, docs): assert 0 <= prediction["answers"][0].score <= 1 assert prediction["answers"][0].context == "My name is Carla and I live in Berlin" assert len(prediction["answers"]) == 5 + if isinstance(reader, FARMReader): + assert prediction["answers"][0].meta["answer_page_number"] == 2 def test_output_batch_single_query_single_doc_list(reader, docs): @@ -514,3 +582,13 @@ def test_farmreader_predict_batch_preprocessor_batching(mocked_qa_inferencer, do # We expect 5 calls to the QAInferencer (2 queries * 5 docs / 2 batch_size) assert reader.inferencer.inference_from_objects.call_count == 5 + + +@pytest.mark.unit +def test_farmreader_init_called_with() -> None: + with patch("haystack.nodes.FARMReader.__init__") as mock_ranker_init: + mock_ranker_init.return_value = None + _ = FARMReader(model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16}) + mock_ranker_init.assert_called_once_with( + model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16} + ) diff --git a/test/nodes/test_route_documents.py b/test/nodes/test_route_documents.py index 2100aed964..2f9ea26338 100644 --- a/test/nodes/test_route_documents.py +++ b/test/nodes/test_route_documents.py @@ -52,7 +52,7 @@ def test_routedocuments_by_content_type_return_remaining(docs_diff_types): @pytest.mark.unit -def test_routedocuments_by_metafield(docs): +def test_routedocuments_by_metafield_str(docs): route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"]) assert route_documents.outgoing_edges == 3 result, _ = route_documents.run(docs) @@ -65,6 +65,40 @@ def test_routedocuments_by_metafield(docs): assert result["output_3"][0].meta["meta_field"] == "test5" +@pytest.mark.unit +def 
test_routedocuments_by_metafield_int(): + docs = [ + Document(content="doc 1", meta={"meta_field": 1}), + Document(content="doc 2", meta={"meta_field": 1}), + Document(content="doc 3", meta={"meta_field": 2}), + ] + route_documents = RouteDocuments(split_by="meta_field", metadata_values=[1, 2]) + assert route_documents.outgoing_edges == 2 + result, _ = route_documents.run(docs) + assert len(result["output_1"]) == 2 + assert len(result["output_2"]) == 1 + assert "output_4" not in result + assert result["output_1"][0].meta["meta_field"] == 1 + assert result["output_2"][0].meta["meta_field"] == 2 + + +@pytest.mark.unit +def test_routedocuments_by_metafield_bool(): + docs = [ + Document(content="doc 1", meta={"meta_field": True}), + Document(content="doc 2", meta={"meta_field": True}), + Document(content="doc 3", meta={"meta_field": False}), + ] + route_documents = RouteDocuments(split_by="meta_field", metadata_values=[True, False]) + assert route_documents.outgoing_edges == 2 + result, _ = route_documents.run(docs) + assert len(result["output_1"]) == 2 + assert len(result["output_2"]) == 1 + assert "output_4" not in result + assert result["output_1"][0].meta["meta_field"] == True + assert result["output_2"][0].meta["meta_field"] == False + + @pytest.mark.unit def test_routedocuments_by_metafield_return_remaning(docs): route_documents = RouteDocuments( diff --git a/test/nodes/test_summarizer.py b/test/nodes/test_summarizer.py index b44a469614..4becef3411 100644 --- a/test/nodes/test_summarizer.py +++ b/test/nodes/test_summarizer.py @@ -1,5 +1,8 @@ import pytest +from unittest.mock import patch +import logging +import torch import haystack from haystack.utils.torch_utils import ListDataset from haystack.schema import Document @@ -83,3 +86,26 @@ def test_summarization_batch_multiple_doc_lists(summarizer): assert len(summarized_docs[0]) == len(DOCS) for expected_summary, summary in zip(EXPECTED_SUMMARIES, summarized_docs[0]): assert expected_summary == summary.meta["summary"] + + +@pytest.mark.unit +def test_init_called_with(): + with patch("haystack.nodes.TransformersSummarizer.__init__") as mock_summarizer_init: + mock_summarizer_init.return_value = None + _ = TransformersSummarizer( + model_name_or_path="fake_model", use_gpu=False, pipeline_kwargs={"torch_dtype": torch.float16} + ) + mock_summarizer_init.assert_called_once_with( + model_name_or_path="fake_model", use_gpu=False, pipeline_kwargs={"torch_dtype": torch.float16} + ) + + +@pytest.mark.unit +def test_summarization_device_warning(caplog, mock_models): + with caplog.at_level(logging.WARNING): + _ = TransformersSummarizer( + model_name_or_path="irrelevant/anyway", + use_gpu=True, + devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ) + assert "Multiple devices are not supported" in caplog.text diff --git a/test/pipelines/test_eval.py b/test/pipelines/test_eval.py index 19b952c14d..5dbca833be 100644 --- a/test/pipelines/test_eval.py +++ b/test/pipelines/test_eval.py @@ -8,7 +8,6 @@ import responses from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore -from haystack.nodes.answer_generator.openai import OpenAIAnswerGenerator from haystack.nodes.preprocessor import PreProcessor from haystack.nodes.prompt.prompt_node import PromptNode from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier @@ -16,7 +15,7 @@ from haystack.nodes.retriever.sparse import BM25Retriever from haystack.nodes.summarizer.transformers import TransformersSummarizer from haystack.pipelines.base import 
Pipeline -from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline, SearchSummarizationPipeline +from haystack.pipelines import ExtractiveQAPipeline, SearchSummarizationPipeline from haystack.pipelines.standard_pipelines import ( DocumentSearchPipeline, FAQPipeline, @@ -596,109 +595,6 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path, eval_labels): assert isinstance(value, float) -@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) -@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) -@responses.activate -def test_generative_qa_eval(retriever_with_docs, tmp_path, eval_labels): - labels = eval_labels[:1] - responses.add( - responses.POST, - "https://api.openai.com/v1/completions", - json={"choices": [{"text": "test", "finish_reason": "stop"}, {"text": "test2", "finish_reason": "stop"}]}, - status=200, - ) - responses.add_passthru("https://openaipublic.blob.core.windows.net") - generator = OpenAIAnswerGenerator(api_key="dummy", top_k=2) - pipeline = GenerativeQAPipeline(generator=generator, retriever=retriever_with_docs) - eval_result = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}}) - - metrics = eval_result.calculate_metrics(document_scope="document_id") - - generator_result = eval_result["Generator"] - retriever_result = eval_result["Retriever"] - - expected_generator_result_columns = [ - "answer", # answer-specific - "exact_match", # answer-specific - "f1", # answer-specific - # "sas", # answer-specific optional - "exact_match_context_scope", # answer-specific - "f1_context_scope", # answer-specific - # "sas_context_scope", # answer-specific optional - "exact_match_document_id_scope", # answer-specific - "f1_document_id_scope", # answer-specific - # "sas_document_id_scope", # answer-specific optional - "exact_match_document_id_and_context_scope", # answer-specific - "f1_document_id_and_context_scope", # answer-specific - # "sas_document_id_and_context_scope", # answer-specific optional - "offsets_in_document", # answer-specific - "gold_offsets_in_documents", # answer-specific - "offsets_in_context", # answer-specific - "gold_offsets_in_contexts", # answer-specific - "gold_answers_exact_match", # answer-specific - "gold_answers_f1", # answer-specific - # "gold_answers_sas", # answer-specific optional - "document_ids", # answer-specific - "prompt", # answer-specific - ] - - expected_retriever_result_columns = [ - "gold_id_match", # doc-specific - "context_match", # doc-specific - "answer_match", # doc-specific - "gold_id_or_answer_match", # doc-specific - "gold_id_and_answer_match", # doc-specific - "gold_id_or_context_match", # doc-specific - "gold_id_and_context_match", # doc-specific - "gold_id_and_context_and_answer_match", # doc-specific - "context_and_answer_match", # doc-specific - "gold_answers_match", # doc-specific, - "document_id", # doc-specific - ] - - expected_generic_result_columns = [ - "multilabel_id", # generic - "query", # generic - "filters", # generic - "context", # generic - "gold_contexts", # generic - "gold_documents_id_match", # generic - "gold_contexts_similarity", # generic - "type", # generic - "node", # generic - "eval_mode", # generic - "rank", # generic - "gold_document_ids", # generic - "gold_answers", # generic - # "custom_document_id", # generic optional - # "gold_custom_document_ids", # generic optional - ] - - # all expected columns are part of the evaluation result dataframe - assert 
sorted(expected_generator_result_columns + expected_generic_result_columns + ["index"]) == sorted( - generator_result.columns - ) - assert sorted(expected_retriever_result_columns + expected_generic_result_columns + ["index"]) == sorted( - retriever_result.columns - ) - - assert generator_result["prompt"].iloc[0] is not None - - # assert metrics are floats - for node_metrics in metrics.values(): - for value in node_metrics.values(): - assert isinstance(value, float) - - eval_result.save(tmp_path) - saved_eval_result = EvaluationResult.load(tmp_path) - - for key, df in eval_result.node_results.items(): - pd.testing.assert_frame_equal(df, saved_eval_result[key]) - - loaded_metrics = saved_eval_result.calculate_metrics(document_scope="document_id") - assert metrics == loaded_metrics - - @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) def test_generative_qa_w_promptnode_eval(retriever_with_docs, tmp_path, eval_labels): diff --git a/test/pipelines/test_eval_batch.py b/test/pipelines/test_eval_batch.py index e765ba292b..1ea2de69cb 100644 --- a/test/pipelines/test_eval_batch.py +++ b/test/pipelines/test_eval_batch.py @@ -12,7 +12,7 @@ from haystack.nodes.retriever.sparse import BM25Retriever from haystack.nodes.summarizer.transformers import TransformersSummarizer from haystack.pipelines.base import Pipeline -from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline, SearchSummarizationPipeline +from haystack.pipelines import ExtractiveQAPipeline, SearchSummarizationPipeline from haystack.pipelines.standard_pipelines import ( DocumentSearchPipeline, FAQPipeline, diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index 69fa27f0b9..f9737f110e 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -1,3 +1,4 @@ +from pathlib import Path import ssl import json import platform @@ -16,14 +17,15 @@ from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from haystack.document_stores.memory import InMemoryDocumentStore +from haystack.nodes.file_classifier.file_type import FileTypeClassifier from haystack.nodes.other.join_docs import JoinDocuments from haystack.nodes.base import BaseComponent from haystack.nodes.retriever.sparse import BM25Retriever from haystack.nodes.retriever.sparse import FilterRetriever +from haystack.nodes import Shaper from haystack.pipelines import ( Pipeline, RootNode, - GenerativeQAPipeline, FAQPipeline, ExtractiveQAPipeline, SearchSummarizationPipeline, @@ -39,7 +41,7 @@ from haystack.errors import PipelineConfigError from haystack.nodes import PreProcessor, TextConverter from haystack.utils.deepsetcloud import DeepsetCloudError -from haystack import Answer +from haystack import Answer, Document from ..conftest import ( MOCK_DC, @@ -47,7 +49,6 @@ DC_API_KEY, DC_TEST_INDEX, MockDocumentStore, - MockSeq2SegGenerator, MockRetriever, MockNode, deepset_cloud_fixture, @@ -694,9 +695,6 @@ def test_generate_code_can_handle_weak_cyclic_pipelines(): @pytest.mark.unit def test_pipeline_classify_type(tmp_path): - pipe = GenerativeQAPipeline(generator=MockSeq2SegGenerator(), retriever=MockRetriever()) - assert pipe.get_type().startswith("GenerativeQAPipeline") - pipe = FAQPipeline(retriever=MockRetriever()) assert pipe.get_type().startswith("FAQPipeline") @@ -1026,8 +1024,9 @@ def dc_document_store_matcher(request: 
PreparedRequest) -> Tuple[bool, str]: matches = False reason = "No DeepsetCloudDocumentStore found." request_body = request.body or "" - json_body = yaml.safe_load(request_body) - components = json_body["components"] + json_body = json.loads(request_body) + config = yaml.safe_load(json_body["config"]) + components = config["components"] for component in components: if component["type"].endswith("DocumentStore"): if component["type"] == "DeepsetCloudDocumentStore": @@ -2090,6 +2089,129 @@ def test_fix_to_pipeline_execution_when_join_follows_join(): assert len(documents) == 4 # all four documents should be found +@pytest.mark.unit +def test_pipeline_execution_using_join_preserves_previous_keys(): + document_store_1 = InMemoryDocumentStore() + retriever_1 = FilterRetriever(document_store_1, scale_score=True) + dicts_1 = [{"content": "Alpha", "score": 0.552}] + document_store_1.write_documents(dicts_1) + + document_store_2 = InMemoryDocumentStore() + retriever_2 = FilterRetriever(document_store_2, scale_score=True) + dicts_2 = [{"content": "Beta", "score": 0.542}] + document_store_2.write_documents(dicts_2) + + # Create Shaper to insert "invocation_context" and "test_key" into the node_output + shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key"]) + + pipeline = Pipeline() + pipeline.add_node(component=shaper, name="Shaper", inputs=["Query"]) + pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper"]) + pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper"]) + pipeline.add_node( + component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever1", "Retriever2"] + ) + res = pipeline.run(query="Alpha Beta Gamma Delta") + assert set(res.keys()) == { + "documents", + "labels", + "root_node", + "params", + "test_key", + "invocation_context", + "query", + "node_id", + } + assert res["test_key"] == "Alpha Beta Gamma Delta" + assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key": "Alpha Beta Gamma Delta"} + assert len(res["documents"]) == 2 + + +@pytest.mark.unit +def test_pipeline_execution_using_join_preserves_previous_keys_three_streams(): + document_store_1 = InMemoryDocumentStore() + retriever_1 = FilterRetriever(document_store_1, scale_score=True) + dicts_1 = [{"content": "Alpha", "score": 0.552}] + document_store_1.write_documents(dicts_1) + + document_store_2 = InMemoryDocumentStore() + retriever_2 = FilterRetriever(document_store_2, scale_score=True) + dicts_2 = [{"content": "Beta", "score": 0.542}] + document_store_2.write_documents(dicts_2) + + document_store_3 = InMemoryDocumentStore() + retriever_3 = FilterRetriever(document_store_3, scale_score=True) + dicts_3 = [{"content": "Gamma", "score": 0.532}] + document_store_3.write_documents(dicts_3) + + # Create Shaper to insert "invocation_context" and "test_key" into the node_output + shaper1 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key1"]) + shaper2 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key2"]) + + pipeline = Pipeline() + pipeline.add_node(component=shaper1, name="Shaper1", inputs=["Query"]) + pipeline.add_node(component=shaper2, name="Shaper2", inputs=["Query"]) + pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper2"]) + pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper1"]) + pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper1"]) + + pipeline.add_node( + 
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever3", "Retriever1", "Retriever2"] + ) + res = pipeline.run(query="Alpha Beta Gamma Delta") + assert set(res.keys()) == { + "documents", + "labels", + "root_node", + "params", + "test_key1", + "test_key2", + "invocation_context", + "query", + "node_id", + } + assert res["test_key1"] == "Alpha Beta Gamma Delta" + assert res["test_key2"] == "Alpha Beta Gamma Delta" + assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key1": "Alpha Beta Gamma Delta"} + assert len(res["documents"]) == 3 + + +@pytest.mark.unit +def test_pipeline_execution_using_join_preserves_changed_query(): + shaper1 = Shaper(func="rename", params={"value": "This is a test."}, outputs=["query"]) + shaper2 = Shaper(func="rename", params={"value": "dummy value 1"}, outputs=["dummy1"]) + shaper3 = Shaper(func="rename", params={"value": "dummy value 2"}, outputs=["dummy2"]) + pipeline = Pipeline() + pipeline.add_node(component=shaper1, name="Shaper1", inputs=["Query"]) + pipeline.add_node(component=shaper2, name="DummyNode1", inputs=["Query"]) + pipeline.add_node(component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Shaper1", "DummyNode1"]) + pipeline.add_node(component=shaper3, name="DummyNode2", inputs=["Join"]) + res = pipeline.run(query="Alpha Beta Gamma Delta", debug=True, documents=[Document(content="Test Document")]) + assert res["_debug"]["Shaper1"]["input"]["query"] == "Alpha Beta Gamma Delta" + assert res["_debug"]["Join"]["input"]["query"] == "This is a test." + assert res["_debug"]["DummyNode2"]["input"]["query"] == "This is a test." + assert res["query"] == "This is a test." + + +@pytest.mark.unit +@pytest.mark.parametrize("file_type", ["csv", "xml"]) +def test_pipeline_execution_can_handle_unknown_edge_for_classifier(file_type: str) -> None: + """Tests running a classifier against an unexpected file type. + + We need to route unexpected file types to a dead-end node.
+ Simply returning "None" for a classification does not work, since + the pipeline will raise an error, trying to route the result to the next node + here: https://github.com/deepset-ai/haystack/blob/b45ecb355636c3227185e97ee595006c06d17470/haystack/pipelines/base.py#L573 + + See fix pr: https://github.com/deepset-ai/haystack/pull/7589 + """ + classifier = FileTypeClassifier(raise_on_error=False) + pipeline = Pipeline() + pipeline.add_node(component=classifier, name="FileTypeClassifier", inputs=["File"]) + res = pipeline.run_batch(file_paths=[f"./test.{file_type}"]) + assert res["file_paths"] == [Path(f"./test.{file_type}")] + + @pytest.mark.unit def test_update_config_hash(): fake_configs = { diff --git a/test/pipelines/test_standard_pipelines.py b/test/pipelines/test_standard_pipelines.py index 2d6523d7a6..d8512c2fbd 100644 --- a/test/pipelines/test_standard_pipelines.py +++ b/test/pipelines/test_standard_pipelines.py @@ -79,7 +79,7 @@ def test_webqa_pipeline(): search_key = os.environ.get("SERPERDEV_API_KEY") openai_key = os.environ.get("OPENAI_API_KEY") pn = PromptNode( - "text-davinci-003", + "gpt-3.5-turbo-instruct", api_key=openai_key, max_length=256, default_prompt_template="question-answering-with-document-scores", diff --git a/test/preview/components/audio/__init__.py b/test/preview/components/audio/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/audio/test_whisper_local.py b/test/preview/components/audio/test_whisper_local.py deleted file mode 100644 index df23b12018..0000000000 --- a/test/preview/components/audio/test_whisper_local.py +++ /dev/null @@ -1,170 +0,0 @@ -import sys -from pathlib import Path -from unittest.mock import patch, MagicMock - -import pytest -import torch - -from haystack.preview.dataclasses import Document -from haystack.preview.components.audio import LocalWhisperTranscriber - - -SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files" - - -class TestLocalWhisperTranscriber: - @pytest.mark.unit - def test_init(self): - transcriber = LocalWhisperTranscriber( - model_name_or_path="large-v2" - ) # Doesn't matter if it's huge, the model is not loaded in init. 
- assert transcriber.model_name == "large-v2" - assert transcriber.device == torch.device("cpu") - assert transcriber._model is None - - @pytest.mark.unit - def test_init_wrong_model(self): - with pytest.raises(ValueError, match="Model name 'whisper-1' not recognized"): - LocalWhisperTranscriber(model_name_or_path="whisper-1") - - @pytest.mark.unit - def test_to_dict(self): - transcriber = LocalWhisperTranscriber() - data = transcriber.to_dict() - assert data == { - "type": "haystack.preview.components.audio.whisper_local.LocalWhisperTranscriber", - "init_parameters": {"model_name_or_path": "large", "device": "cpu", "whisper_params": {}}, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - transcriber = LocalWhisperTranscriber( - model_name_or_path="tiny", - device="cuda", - whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]}, - ) - data = transcriber.to_dict() - assert data == { - "type": "haystack.preview.components.audio.whisper_local.LocalWhisperTranscriber", - "init_parameters": { - "model_name_or_path": "tiny", - "device": "cuda", - "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}, - }, - } - - @pytest.mark.unit - def test_warmup(self): - with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper: - transcriber = LocalWhisperTranscriber(model_name_or_path="large-v2") - mocked_whisper.load_model.assert_not_called() - transcriber.warm_up() - mocked_whisper.load_model.assert_called_once_with("large-v2", device=torch.device(type="cpu")) - - @pytest.mark.unit - def test_warmup_doesnt_reload(self): - with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper: - transcriber = LocalWhisperTranscriber(model_name_or_path="large-v2") - transcriber.warm_up() - transcriber.warm_up() - mocked_whisper.load_model.assert_called_once() - - @pytest.mark.unit - def test_run_with_path(self): - comp = LocalWhisperTranscriber(model_name_or_path="large-v2") - comp._model = MagicMock() - comp._model.transcribe.return_value = { - "text": "test transcription", - "other_metadata": ["other", "meta", "data"], - } - results = comp.run(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"]) - expected = Document( - content="test transcription", - meta={ - "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav", - "other_metadata": ["other", "meta", "data"], - }, - ) - assert results["documents"] == [expected] - - @pytest.mark.unit - def test_run_with_str(self): - comp = LocalWhisperTranscriber(model_name_or_path="large-v2") - comp._model = MagicMock() - comp._model.transcribe.return_value = { - "text": "test transcription", - "other_metadata": ["other", "meta", "data"], - } - results = comp.run( - audio_files=[str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute())] - ) - expected = Document( - content="test transcription", - meta={ - "audio_file": str((SAMPLES_PATH / "audio" / "this is the content of the document.wav").absolute()), - "other_metadata": ["other", "meta", "data"], - }, - ) - assert results["documents"] == [expected] - - @pytest.mark.unit - def test_transcribe(self): - comp = LocalWhisperTranscriber(model_name_or_path="large-v2") - comp._model = MagicMock() - comp._model.transcribe.return_value = { - "text": "test transcription", - "other_metadata": ["other", "meta", "data"], - } - results = comp.transcribe(audio_files=[SAMPLES_PATH / "audio" / "this is the content of the document.wav"]) - 
expected = Document( - content="test transcription", - meta={ - "audio_file": SAMPLES_PATH / "audio" / "this is the content of the document.wav", - "other_metadata": ["other", "meta", "data"], - }, - ) - assert results == [expected] - - @pytest.mark.unit - def test_transcribe_stream(self): - comp = LocalWhisperTranscriber(model_name_or_path="large-v2") - comp._model = MagicMock() - comp._model.transcribe.return_value = { - "text": "test transcription", - "other_metadata": ["other", "meta", "data"], - } - results = comp.transcribe( - audio_files=[open(SAMPLES_PATH / "audio" / "this is the content of the document.wav", "rb")] - ) - expected = Document( - content="test transcription", - meta={"audio_file": "<>", "other_metadata": ["other", "meta", "data"]}, - ) - assert results == [expected] - - @pytest.mark.integration - @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="ffmpeg not installed on Windows CI") - def test_whisper_local_transcriber(self, preview_samples_path): - comp = LocalWhisperTranscriber(model_name_or_path="medium", whisper_params={"language": "english"}) - comp.warm_up() - output = comp.run( - audio_files=[ - preview_samples_path / "audio" / "this is the content of the document.wav", - str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute()), - open(preview_samples_path / "audio" / "answer.wav", "rb"), - ] - ) - docs = output["documents"] - assert len(docs) == 3 - - assert docs[0].content.strip().lower() == "this is the content of the document." - assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].meta["audio_file"] - - assert docs[1].content.strip().lower() == "the context for this answer is here." - assert ( - str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute()) - == docs[1].meta["audio_file"] - ) - - assert docs[2].content.strip().lower() == "answer." - assert docs[2].meta["audio_file"] == "<>" diff --git a/test/preview/components/audio/test_whisper_remote.py b/test/preview/components/audio/test_whisper_remote.py deleted file mode 100644 index df6b8067f5..0000000000 --- a/test/preview/components/audio/test_whisper_remote.py +++ /dev/null @@ -1,254 +0,0 @@ -import os -from unittest.mock import patch -from pathlib import Path - -import openai -import pytest -from openai.util import convert_to_openai_object - -from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber -from haystack.preview.dataclasses import ByteStream - - -def mock_openai_response(response_format="json", **kwargs) -> openai.openai_object.OpenAIObject: - if response_format == "json": - dict_response = {"text": "test transcription"} - # Currently only "json" is supported. - else: - dict_response = {} - - return convert_to_openai_object(dict_response) - - -class TestRemoteWhisperTranscriber: - @pytest.mark.unit - def test_init_no_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - error_msg = "RemoteWhisperTranscriber expects an OpenAI API key." 
- with pytest.raises(ValueError, match=error_msg): - RemoteWhisperTranscriber(api_key=None) - - def test_init_key_env_var(self, monkeypatch): - openai.api_key = None - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") - RemoteWhisperTranscriber(api_key=None) - assert openai.api_key == "test_api_key" - - def test_init_key_module_env_and_global_var(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2") - openai.api_key = "test_api_key_1" - RemoteWhisperTranscriber(api_key=None) - # The module global variable takes preference - assert openai.api_key == "test_api_key_1" - - @pytest.mark.unit - def test_init_default(self): - transcriber = RemoteWhisperTranscriber(api_key="test_api_key") - - assert openai.api_key == "test_api_key" - assert transcriber.model_name == "whisper-1" - assert transcriber.organization is None - assert transcriber.api_base_url == "https://api.openai.com/v1" - assert transcriber.whisper_params == {"response_format": "json"} - - @pytest.mark.unit - def test_init_custom_parameters(self): - transcriber = RemoteWhisperTranscriber( - api_key="test_api_key", - model_name="whisper-1", - organization="test-org", - api_base_url="test_api_url", - language="en", - prompt="test-prompt", - response_format="json", - temperature="0.5", - ) - - assert openai.api_key == "test_api_key" - assert transcriber.model_name == "whisper-1" - assert transcriber.organization == "test-org" - assert transcriber.api_base_url == "test_api_url" - assert transcriber.whisper_params == { - "language": "en", - "prompt": "test-prompt", - "response_format": "json", - "temperature": "0.5", - } - - @pytest.mark.unit - def test_to_dict_default_parameters(self): - transcriber = RemoteWhisperTranscriber(api_key="test_api_key") - data = transcriber.to_dict() - assert data == { - "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber", - "init_parameters": { - "model_name": "whisper-1", - "api_base_url": "https://api.openai.com/v1", - "organization": None, - "response_format": "json", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - transcriber = RemoteWhisperTranscriber( - api_key="test_api_key", - model_name="whisper-1", - organization="test-org", - api_base_url="test_api_url", - language="en", - prompt="test-prompt", - response_format="json", - temperature="0.5", - ) - data = transcriber.to_dict() - assert data == { - "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber", - "init_parameters": { - "model_name": "whisper-1", - "organization": "test-org", - "api_base_url": "test_api_url", - "language": "en", - "prompt": "test-prompt", - "response_format": "json", - "temperature": "0.5", - }, - } - - def test_from_dict_with_defualt_parameters(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") - - data = { - "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber", - "init_parameters": { - "model_name": "whisper-1", - "api_base_url": "https://api.openai.com/v1", - "organization": None, - "response_format": "json", - }, - } - - transcriber = RemoteWhisperTranscriber.from_dict(data) - - assert openai.api_key == "test_api_key" - assert transcriber.model_name == "whisper-1" - assert transcriber.organization is None - assert transcriber.api_base_url == "https://api.openai.com/v1" - assert transcriber.whisper_params == {"response_format": "json"} - - def 
test_from_dict_with_custom_init_parameters(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") - - data = { - "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber", - "init_parameters": { - "model_name": "whisper-1", - "organization": "test-org", - "api_base_url": "test_api_url", - "language": "en", - "prompt": "test-prompt", - "response_format": "json", - "temperature": "0.5", - }, - } - transcriber = RemoteWhisperTranscriber.from_dict(data) - - assert openai.api_key == "test_api_key" - assert transcriber.model_name == "whisper-1" - assert transcriber.organization == "test-org" - assert transcriber.api_base_url == "test_api_url" - assert transcriber.whisper_params == { - "language": "en", - "prompt": "test-prompt", - "response_format": "json", - "temperature": "0.5", - } - - def test_from_dict_with_defualt_parameters_no_env_var(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - - data = { - "type": "haystack.preview.components.audio.whisper_remote.RemoteWhisperTranscriber", - "init_parameters": { - "model_name": "whisper-1", - "api_base_url": "https://api.openai.com/v1", - "organization": None, - "response_format": "json", - }, - } - - with pytest.raises(ValueError, match="RemoteWhisperTranscriber expects an OpenAI API key."): - RemoteWhisperTranscriber.from_dict(data) - - @pytest.mark.unit - def test_run_str(self, preview_samples_path): - with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch: - model = "whisper-1" - file_path = str(preview_samples_path / "audio" / "this is the content of the document.wav") - openai_audio_patch.transcribe.side_effect = mock_openai_response - - transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json") - result = transcriber.run(sources=[file_path]) - - assert result["documents"][0].content == "test transcription" - assert result["documents"][0].meta["file_path"] == file_path - - @pytest.mark.unit - def test_run_path(self, preview_samples_path): - with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch: - model = "whisper-1" - file_path = preview_samples_path / "audio" / "this is the content of the document.wav" - openai_audio_patch.transcribe.side_effect = mock_openai_response - - transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json") - result = transcriber.run(sources=[file_path]) - - assert result["documents"][0].content == "test transcription" - assert result["documents"][0].meta["file_path"] == file_path - - @pytest.mark.unit - def test_run_bytestream(self, preview_samples_path): - with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch: - model = "whisper-1" - file_path = preview_samples_path / "audio" / "this is the content of the document.wav" - openai_audio_patch.transcribe.side_effect = mock_openai_response - - transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json") - with open(file_path, "rb") as audio_stream: - byte_stream = audio_stream.read() - audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())}) - - result = transcriber.run(sources=[audio_file]) - - assert result["documents"][0].content == "test transcription" - assert result["documents"][0].meta["file_path"] == str(file_path.absolute()) - - @pytest.mark.skipif( - not 
os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_whisper_remote_transcriber(self, preview_samples_path): - transcriber = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY")) - - paths = [ - preview_samples_path / "audio" / "this is the content of the document.wav", - str(preview_samples_path / "audio" / "the context for this answer is here.wav"), - ByteStream.from_file_path(preview_samples_path / "audio" / "answer.wav"), - ] - - output = transcriber.run(sources=paths) - - docs = output["documents"] - assert len(docs) == 3 - assert docs[0].content.strip().lower() == "this is the content of the document." - assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].meta["file_path"] - - assert docs[1].content.strip().lower() == "the context for this answer is here." - assert ( - str(preview_samples_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"] - ) - - assert docs[2].content.strip().lower() == "answer." diff --git a/test/preview/components/builders/__init__.py b/test/preview/components/builders/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/builders/test_answer_builder.py b/test/preview/components/builders/test_answer_builder.py deleted file mode 100644 index 7ce6d56e7b..0000000000 --- a/test/preview/components/builders/test_answer_builder.py +++ /dev/null @@ -1,154 +0,0 @@ -import logging - -import pytest - -from haystack.preview import GeneratedAnswer, Document -from haystack.preview.components.builders.answer_builder import AnswerBuilder - - -class TestAnswerBuilder: - @pytest.mark.unit - def test_run_unmatching_input_len(self): - component = AnswerBuilder() - with pytest.raises(ValueError): - component.run(query="query", replies=["reply1"], metadata=[{"test": "meta"}, {"test": "meta2"}]) - - @pytest.mark.unit - def test_run_without_meta(self): - component = AnswerBuilder() - output = component.run(query="query", replies=["reply1"]) - answers = output["answers"] - assert answers[0].data == "reply1" - assert answers[0].metadata == {} - assert answers[0].query == "query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - @pytest.mark.unit - def test_run_meta_is_an_empty_list(self): - component = AnswerBuilder() - output = component.run(query="query", replies=["reply1"], metadata=[]) - answers = output["answers"] - assert answers[0].data == "reply1" - assert answers[0].metadata == {} - assert answers[0].query == "query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - def test_run_without_pattern(self): - component = AnswerBuilder() - output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}]) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "Answer: AnswerString" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - def test_run_with_pattern_with_capturing_group(self): - component = AnswerBuilder(pattern=r"Answer: (.*)") - output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}]) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "AnswerString" - assert answers[0].metadata == {} - assert 
answers[0].query == "test query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - def test_run_with_pattern_without_capturing_group(self): - component = AnswerBuilder(pattern=r"'.*'") - output = component.run(query="test query", replies=["Answer: 'AnswerString'"], metadata=[{}]) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "'AnswerString'" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - def test_run_with_pattern_with_more_than_one_capturing_group(self): - with pytest.raises(ValueError, match="contains multiple capture groups"): - AnswerBuilder(pattern=r"Answer: (.*), (.*)") - - def test_run_with_pattern_set_at_runtime(self): - component = AnswerBuilder(pattern="unused pattern") - output = component.run( - query="test query", replies=["Answer: AnswerString"], metadata=[{}], pattern=r"Answer: (.*)" - ) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "AnswerString" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert answers[0].documents == [] - assert isinstance(answers[0], GeneratedAnswer) - - def test_run_with_documents_without_reference_pattern(self): - component = AnswerBuilder() - output = component.run( - query="test query", - replies=["Answer: AnswerString"], - metadata=[{}], - documents=[Document(content="test doc 1"), Document(content="test doc 2")], - ) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "Answer: AnswerString" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert len(answers[0].documents) == 2 - assert answers[0].documents[0].content == "test doc 1" - assert answers[0].documents[1].content == "test doc 2" - - def test_run_with_documents_with_reference_pattern(self): - component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") - output = component.run( - query="test query", - replies=["Answer: AnswerString[2]"], - metadata=[{}], - documents=[Document(content="test doc 1"), Document(content="test doc 2")], - ) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "Answer: AnswerString[2]" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert len(answers[0].documents) == 1 - assert answers[0].documents[0].content == "test doc 2" - - def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog): - component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]") - with caplog.at_level(logging.WARNING): - output = component.run( - query="test query", - replies=["Answer: AnswerString[3]"], - metadata=[{}], - documents=[Document(content="test doc 1"), Document(content="test doc 2")], - ) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "Answer: AnswerString[3]" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert len(answers[0].documents) == 0 - assert "Document index '3' referenced in Generator output is out of range." 
in caplog.text - - def test_run_with_reference_pattern_set_at_runtime(self): - component = AnswerBuilder(reference_pattern="unused pattern") - output = component.run( - query="test query", - replies=["Answer: AnswerString[2][3]"], - metadata=[{}], - documents=[Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")], - reference_pattern="\\[(\\d+)\\]", - ) - answers = output["answers"] - assert len(answers) == 1 - assert answers[0].data == "Answer: AnswerString[2][3]" - assert answers[0].metadata == {} - assert answers[0].query == "test query" - assert len(answers[0].documents) == 2 - assert answers[0].documents[0].content == "test doc 2" - assert answers[0].documents[1].content == "test doc 3" diff --git a/test/preview/components/builders/test_dynamic_prompt_builder.py b/test/preview/components/builders/test_dynamic_prompt_builder.py deleted file mode 100644 index 007040da1e..0000000000 --- a/test/preview/components/builders/test_dynamic_prompt_builder.py +++ /dev/null @@ -1,154 +0,0 @@ -from typing import List, Union - -import pytest -from jinja2 import TemplateSyntaxError - -from haystack.preview.components.builders.dynamic_prompt_builder import DynamicPromptBuilder -from haystack.preview.dataclasses import ChatMessage - - -class TestDynamicPromptBuilder: - def test_initialization_chat_on(self): - runtime_variables = ["var1", "var2", "var3"] - builder = DynamicPromptBuilder(runtime_variables, chat_mode=True) - assert builder.runtime_variables == runtime_variables - assert builder.chat_mode - - # regardless of the chat mode - # we have inputs that contain: prompt_source, template_variables + runtime_variables - expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) - assert set(builder.__canals_input__.keys()) == expected_keys - - # response is always prompt regardless of chat mode - assert set(builder.__canals_output__.keys()) == {"prompt"} - - # prompt_source is a list of ChatMessage or a string - assert builder.__canals_input__["prompt_source"].type == Union[List[ChatMessage], str] - - # output is always prompt, but the type is different depending on the chat mode - assert builder.__canals_output__["prompt"].type == List[ChatMessage] - - def test_initialization_chat_off(self): - runtime_variables = ["var1", "var2"] - builder = DynamicPromptBuilder(runtime_variables, False) - assert builder.runtime_variables == runtime_variables - assert not builder.chat_mode - - # regardless of the chat mode - # we have inputs that contain: prompt_source, template_variables + runtime_variables - expected_keys = set(runtime_variables + ["prompt_source", "template_variables"]) - assert set(builder.__canals_input__.keys()) == expected_keys - - # response is always prompt regardless of chat mode - assert set(builder.__canals_output__.keys()) == {"prompt"} - - # prompt_source is a list of ChatMessage or a string - assert builder.__canals_input__["prompt_source"].type == Union[List[ChatMessage], str] - - # output is always prompt, but the type is different depending on the chat mode - assert builder.__canals_output__["prompt"].type == str - - def test_to_dict_method_returns_expected_dictionary(self): - runtime_variables = ["var1", "var2", "var3"] - chat_mode = True - builder = DynamicPromptBuilder(runtime_variables, chat_mode) - expected_dict = { - "type": "haystack.preview.components.builders.dynamic_prompt_builder.DynamicPromptBuilder", - "init_parameters": {"runtime_variables": runtime_variables, "chat_mode": chat_mode}, - } - assert 
builder.to_dict() == expected_dict - - def test_processing_a_simple_template_with_provided_variables(self): - runtime_variables = ["var1", "var2", "var3"] - chat_mode = True - - builder = DynamicPromptBuilder(runtime_variables, chat_mode) - - template = "Hello, {{ name }}!" - template_variables = {"name": "John"} - expected_result = "Hello, John!" - - assert builder._process_simple_template(template, template_variables) == expected_result - - def test_processing_a_simple_template_with_invalid_template(self): - runtime_variables = ["var1", "var2", "var3"] - chat_mode = True - - builder = DynamicPromptBuilder(runtime_variables, chat_mode) - - template = "Hello, {{ name }!" - template_variables = {"name": "John"} - with pytest.raises(TemplateSyntaxError): - builder._process_simple_template(template, template_variables) - - def test_processing_a_simple_template_with_missing_variables(self): - runtime_variables = ["var1", "var2", "var3"] - - builder = DynamicPromptBuilder(runtime_variables, False) - - with pytest.raises(ValueError): - builder._process_simple_template("Hello, {{ name }}!", {}) - - def test_non_empty_chat_messages(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=True) - prompt_source = [ChatMessage.from_system(content="Hello"), ChatMessage.from_user(content="Hello, {{ who }}!")] - template_variables = {"who": "World"} - - result = prompt_builder._process_chat_messages(prompt_source, template_variables) - - assert result == [ChatMessage.from_system(content="Hello"), ChatMessage.from_user(content="Hello, World!")] - - def test_single_chat_message(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=True) - prompt_source = [ChatMessage.from_user(content="Hello, {{ who }}!")] - template_variables = {"who": "World"} - - result = prompt_builder._process_chat_messages(prompt_source, template_variables) - - assert result == [ChatMessage.from_user(content="Hello, World!")] - - def test_empty_chat_message_list(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=True) - - with pytest.raises(ValueError): - prompt_builder._process_chat_messages(prompt_source=[], template_variables={}) - - def test_chat_message_list_with_mixed_object_list(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=True) - - with pytest.raises(ValueError): - prompt_builder._process_chat_messages( - prompt_source=[ChatMessage.from_user("Hello"), "there world"], template_variables={} - ) - - def test_chat_message_list_with_missing_variables(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=True) - prompt_source = [ChatMessage.from_user(content="Hello, {{ who }}!")] - - # Call the _process_chat_messages method and expect a ValueError - with pytest.raises(ValueError): - prompt_builder._process_chat_messages(prompt_source, template_variables={}) - - def test_missing_template_variables(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"]) - - # missing template variable city - with pytest.raises(ValueError): - prompt_builder._validate_template("Hello, I'm {{ name }}, and I live in {{ city }}.", {"name"}) - - # missing template variable name - with pytest.raises(ValueError): - prompt_builder._validate_template("Hello, I'm {{ name }}, and I live in {{ city }}.", {"city"}) - - # completely unknown template variable - with pytest.raises(ValueError): - prompt_builder._validate_template("Hello, I'm {{ name }}, and 
I live in {{ city }}.", {"age"}) - - def test_provided_template_variables(self): - prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"]) - - # both variables are provided - prompt_builder._validate_template("Hello, I'm {{ name }}, and I live in {{ city }}.", {"name", "city"}) - - # provided variables are a superset of the required variables - prompt_builder._validate_template("Hello, I'm {{ name }}, and I live in {{ city }}.", {"name", "city", "age"}) diff --git a/test/preview/components/builders/test_prompt_builder.py b/test/preview/components/builders/test_prompt_builder.py deleted file mode 100644 index e43e99bb92..0000000000 --- a/test/preview/components/builders/test_prompt_builder.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest - -from haystack.preview.components.builders.prompt_builder import PromptBuilder - - -@pytest.mark.unit -def test_init(): - builder = PromptBuilder(template="This is a {{ variable }}") - assert builder._template_string == "This is a {{ variable }}" - - -@pytest.mark.unit -def test_to_dict(): - builder = PromptBuilder(template="This is a {{ variable }}") - res = builder.to_dict() - assert res == { - "type": "haystack.preview.components.builders.prompt_builder.PromptBuilder", - "init_parameters": {"template": "This is a {{ variable }}"}, - } - - -@pytest.mark.unit -def test_run(): - builder = PromptBuilder(template="This is a {{ variable }}") - res = builder.run(variable="test") - assert res == {"prompt": "This is a test"} - - -@pytest.mark.unit -def test_run_without_input(): - builder = PromptBuilder(template="This is a template without input") - res = builder.run() - assert res == {"prompt": "This is a template without input"} - - -@pytest.mark.unit -def test_run_with_missing_input(): - builder = PromptBuilder(template="This is a {{ variable }}") - res = builder.run() - assert res == {"prompt": "This is a "} diff --git a/test/preview/components/caching/test_url_cache_checker.py b/test/preview/components/caching/test_url_cache_checker.py deleted file mode 100644 index 1a9487f045..0000000000 --- a/test/preview/components/caching/test_url_cache_checker.py +++ /dev/null @@ -1,95 +0,0 @@ -import pytest - -from haystack.preview import Document, DeserializationError -from haystack.preview.testing.factory import document_store_class -from haystack.preview.document_stores.in_memory import InMemoryDocumentStore -from haystack.preview.components.caching.url_cache_checker import UrlCacheChecker - - -class TestUrlCacheChecker: - @pytest.mark.unit - def test_to_dict(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - component = UrlCacheChecker(document_store=mocked_docstore_class()) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "url_field": "url", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - component = UrlCacheChecker(document_store=mocked_docstore_class(), url_field="my_url_field") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "url_field": "my_url_field", - }, - } - - 
@pytest.mark.unit - def test_from_dict(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - data = { - "type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "url_field": "my_url_field", - }, - } - component = UrlCacheChecker.from_dict(data) - assert isinstance(component.document_store, mocked_docstore_class) - assert component.url_field == "my_url_field" - - @pytest.mark.unit - def test_from_dict_without_docstore(self): - data = {"type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", "init_parameters": {}} - with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): - UrlCacheChecker.from_dict(data) - - @pytest.mark.unit - def test_from_dict_without_docstore_type(self): - data = { - "type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", - "init_parameters": {"document_store": {"init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): - UrlCacheChecker.from_dict(data) - - @pytest.mark.unit - def test_from_dict_nonexisting_docstore(self): - data = { - "type": "haystack.preview.components.caching.url_cache_checker.UrlCacheChecker", - "init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."): - UrlCacheChecker.from_dict(data) - - @pytest.mark.unit - def test_run(self): - docstore = InMemoryDocumentStore() - documents = [ - Document(content="doc1", meta={"url": "https://example.com/1"}), - Document(content="doc2", meta={"url": "https://example.com/2"}), - Document(content="doc3", meta={"url": "https://example.com/1"}), - Document(content="doc4", meta={"url": "https://example.com/2"}), - ] - docstore.write_documents(documents) - checker = UrlCacheChecker(docstore) - results = checker.run(urls=["https://example.com/1", "https://example.com/5"]) - assert results == {"hits": [documents[0], documents[2]], "misses": ["https://example.com/5"]} diff --git a/test/preview/components/classifiers/test_document_language_classifier.py b/test/preview/components/classifiers/test_document_language_classifier.py deleted file mode 100644 index 53214b3633..0000000000 --- a/test/preview/components/classifiers/test_document_language_classifier.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -import pytest - -from haystack.preview import Document -from haystack.preview.components.classifiers import DocumentLanguageClassifier - - -class TestDocumentLanguageClassifier: - @pytest.mark.unit - def test_init(self): - component = DocumentLanguageClassifier() - assert component.languages == ["en"] - - @pytest.mark.unit - def test_non_document_input(self): - with pytest.raises(TypeError, match="DocumentLanguageClassifier expects a list of Document as input."): - classifier = DocumentLanguageClassifier() - classifier.run(documents="This is an english sentence.") - - @pytest.mark.unit - def test_single_document(self): - with pytest.raises(TypeError, match="DocumentLanguageClassifier expects a list of Document as input."): - classifier = 
DocumentLanguageClassifier() - classifier.run(documents=Document(content="This is an english sentence.")) - - @pytest.mark.unit - def test_empty_list(self): - classifier = DocumentLanguageClassifier() - result = classifier.run(documents=[]) - assert result == {"documents": []} - - @pytest.mark.unit - def test_detect_language(self): - classifier = DocumentLanguageClassifier() - detected_language = classifier.detect_language(Document(content="This is an english sentence.")) - assert detected_language == "en" - - @pytest.mark.unit - def test_classify_as_en_and_unmatched(self): - classifier = DocumentLanguageClassifier() - english_document = Document(content="This is an english sentence.") - german_document = Document(content="Ein deutscher Satz ohne Verb.") - result = classifier.run(documents=[english_document, german_document]) - assert result["documents"][0].meta["language"] == "en" - assert result["documents"][1].meta["language"] == "unmatched" - - @pytest.mark.unit - def test_warning_if_no_language_detected(self, caplog): - with caplog.at_level(logging.WARNING): - classifier = DocumentLanguageClassifier() - classifier.run(documents=[Document(content=".")]) - assert "Langdetect cannot detect the language of Document with id" in caplog.text diff --git a/test/preview/components/converters/__init__.py b/test/preview/components/converters/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/converters/test_azure_ocr_doc_converter.py b/test/preview/components/converters/test_azure_ocr_doc_converter.py deleted file mode 100644 index 83c5075c4b..0000000000 --- a/test/preview/components/converters/test_azure_ocr_doc_converter.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -from unittest.mock import patch, Mock - -import pytest - -from haystack.preview.components.converters.azure import AzureOCRDocumentConverter - - -class TestAzureOCRDocumentConverter: - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("AZURE_AI_API_KEY", raising=False) - with pytest.raises(ValueError, match="AzureOCRDocumentConverter expects an Azure Credential key"): - AzureOCRDocumentConverter(endpoint="test_endpoint") - - @pytest.mark.unit - def test_to_dict(self): - component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.converters.azure.AzureOCRDocumentConverter", - "init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"}, - } - - @pytest.mark.unit - def test_run(self, preview_samples_path): - with patch("haystack.preview.components.converters.azure.DocumentAnalysisClient") as mock_azure_client: - mock_result = Mock(pages=[Mock(lines=[Mock(content="mocked line 1"), Mock(content="mocked line 2")])]) - mock_result.to_dict.return_value = { - "api_version": "2023-02-28-preview", - "model_id": "prebuilt-read", - "content": "mocked line 1\nmocked line 2\n\f", - "pages": [{"lines": [{"content": "mocked line 1"}, {"content": "mocked line 2"}]}], - } - mock_azure_client.return_value.begin_analyze_document.return_value.result.return_value = mock_result - - component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key") - output = component.run(paths=[preview_samples_path / "pdf" / "sample_pdf_1.pdf"]) - document = output["documents"][0] - assert document.content == "mocked line 1\nmocked line 2\n\f" - assert "raw_azure_response" in output - assert 
output["raw_azure_response"][0] == { - "api_version": "2023-02-28-preview", - "model_id": "prebuilt-read", - "content": "mocked line 1\nmocked line 2\n\f", - "pages": [{"lines": [{"content": "mocked line 1"}, {"content": "mocked line 2"}]}], - } - - @pytest.mark.integration - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_ENDPOINT", None), reason="Azure credentials not available") - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available") - def test_run_with_pdf_file(self, preview_samples_path): - component = AzureOCRDocumentConverter( - endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"] - ) - output = component.run(paths=[preview_samples_path / "pdf" / "sample_pdf_1.pdf"]) - documents = output["documents"] - assert len(documents) == 1 - assert "A sample PDF file" in documents[0].content - assert "Page 2 of Sample PDF" in documents[0].content - assert "Page 4 of Sample PDF" in documents[0].content - - @pytest.mark.integration - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_ENDPOINT", None), reason="Azure credentials not available") - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available") - def test_with_image_file(self, preview_samples_path): - component = AzureOCRDocumentConverter( - endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"] - ) - output = component.run(paths=[preview_samples_path / "images" / "haystack-logo.png"]) - documents = output["documents"] - assert len(documents) == 1 - assert "haystack" in documents[0].content - assert "by deepset" in documents[0].content - - @pytest.mark.integration - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_ENDPOINT", None), reason="Azure credentials not available") - @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available") - def test_run_with_docx_file(self, preview_samples_path): - component = AzureOCRDocumentConverter( - endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"] - ) - output = component.run(paths=[preview_samples_path / "docx" / "sample_docx.docx"]) - documents = output["documents"] - assert len(documents) == 1 - assert "Sample Docx File" in documents[0].content - assert "Now we are in Page 2" in documents[0].content - assert "Page 3 was empty this is page 4" in documents[0].content diff --git a/test/preview/components/converters/test_html_to_document.py b/test/preview/components/converters/test_html_to_document.py deleted file mode 100644 index 4e182b279b..0000000000 --- a/test/preview/components/converters/test_html_to_document.py +++ /dev/null @@ -1,156 +0,0 @@ -import logging - -import pytest - -from haystack.preview.components.converters import HTMLToDocument -from haystack.preview.dataclasses import ByteStream - - -class TestHTMLToDocument: - @pytest.mark.unit - def test_run(self, preview_samples_path): - """ - Test if the component runs correctly. - """ - sources = [preview_samples_path / "html" / "what_is_haystack.html"] - converter = HTMLToDocument() - results = converter.run(sources=sources) - docs = results["documents"] - assert len(docs) == 1 - assert "Haystack" in docs[0].content - - @pytest.mark.unit - def test_run_doc_metadata(self, preview_samples_path): - """ - Test if the component runs correctly when metadata is supplied by the user. 
- """ - converter = HTMLToDocument() - sources = [preview_samples_path / "html" / "what_is_haystack.html"] - metadata = [{"file_name": "what_is_haystack.html"}] - results = converter.run(sources=sources, meta=metadata) - docs = results["documents"] - - assert len(docs) == 1 - assert "Haystack" in docs[0].content - assert docs[0].meta == {"file_name": "what_is_haystack.html"} - - @pytest.mark.unit - def test_incorrect_meta(self, preview_samples_path): - """ - Test if the component raises an error when incorrect metadata is supplied by the user. - """ - converter = HTMLToDocument() - sources = [preview_samples_path / "html" / "what_is_haystack.html"] - metadata = [{"file_name": "what_is_haystack.html"}, {"file_name": "haystack.html"}] - with pytest.raises(ValueError, match="The length of the metadata list must match the number of sources."): - converter.run(sources=sources, meta=metadata) - - @pytest.mark.unit - def test_run_bytestream_metadata(self, preview_samples_path): - """ - Test if the component runs correctly when metadata is read from the ByteStream object. - """ - converter = HTMLToDocument() - with open(preview_samples_path / "html" / "what_is_haystack.html", "rb") as file: - byte_stream = file.read() - stream = ByteStream(byte_stream, metadata={"content_type": "text/html", "url": "test_url"}) - - results = converter.run(sources=[stream]) - docs = results["documents"] - - assert len(docs) == 1 - assert "Haystack" in docs[0].content - assert docs[0].meta == {"content_type": "text/html", "url": "test_url"} - - @pytest.mark.unit - def test_run_bytestream_and_doc_metadata(self, preview_samples_path): - """ - Test if the component runs correctly when metadata is read from the ByteStream object and supplied by the user. - - There is no overlap between the metadata received. - """ - converter = HTMLToDocument() - with open(preview_samples_path / "html" / "what_is_haystack.html", "rb") as file: - byte_stream = file.read() - stream = ByteStream(byte_stream, metadata={"content_type": "text/html", "url": "test_url"}) - - metadata = [{"file_name": "what_is_haystack.html"}] - results = converter.run(sources=[stream], meta=metadata) - docs = results["documents"] - - assert len(docs) == 1 - assert "Haystack" in docs[0].content - assert docs[0].meta == {"file_name": "what_is_haystack.html", "content_type": "text/html", "url": "test_url"} - - @pytest.mark.unit - def test_run_bytestream_doc_overlapping_metadata(self, preview_samples_path): - """ - Test if the component runs correctly when metadata is read from the ByteStream object and supplied by the user. - - There is an overlap between the metadata received. - - The component should use the supplied metadata to overwrite the values if there is an overlap between the keys. 
- """ - converter = HTMLToDocument() - with open(preview_samples_path / "html" / "what_is_haystack.html", "rb") as file: - byte_stream = file.read() - # ByteStream has "url" present in metadata - stream = ByteStream(byte_stream, metadata={"content_type": "text/html", "url": "test_url_correct"}) - - # "url" supplied by the user overwrites value present in metadata - metadata = [{"file_name": "what_is_haystack.html", "url": "test_url_new"}] - results = converter.run(sources=[stream], meta=metadata) - docs = results["documents"] - - assert len(docs) == 1 - assert "Haystack" in docs[0].content - assert docs[0].meta == { - "file_name": "what_is_haystack.html", - "content_type": "text/html", - "url": "test_url_new", - } - - @pytest.mark.unit - def test_run_wrong_file_type(self, preview_samples_path, caplog): - """ - Test if the component runs correctly when an input file is not of the expected type. - """ - sources = [preview_samples_path / "audio" / "answer.wav"] - converter = HTMLToDocument() - with caplog.at_level(logging.WARNING): - results = converter.run(sources=sources) - assert "codec can't decode byte" in caplog.text - - assert results["documents"] == [] - - @pytest.mark.unit - def test_run_error_handling(self, caplog): - """ - Test if the component correctly handles errors. - """ - sources = ["non_existing_file.html"] - converter = HTMLToDocument() - with caplog.at_level(logging.WARNING): - results = converter.run(sources=sources) - assert "Could not read non_existing_file.html" in caplog.text - assert results["documents"] == [] - - @pytest.mark.unit - def test_mixed_sources_run(self, preview_samples_path): - """ - Test if the component runs correctly if the input is a mix of paths and ByteStreams. - """ - sources = [ - preview_samples_path / "html" / "what_is_haystack.html", - str((preview_samples_path / "html" / "what_is_haystack.html").absolute()), - ] - with open(preview_samples_path / "html" / "what_is_haystack.html", "rb") as f: - byte_stream = f.read() - sources.append(ByteStream(byte_stream)) - - converter = HTMLToDocument() - results = converter.run(sources=sources) - docs = results["documents"] - assert len(docs) == 3 - for doc in docs: - assert "Haystack" in doc.content diff --git a/test/preview/components/converters/test_markdown_to_document.py b/test/preview/components/converters/test_markdown_to_document.py deleted file mode 100644 index 3dc69429df..0000000000 --- a/test/preview/components/converters/test_markdown_to_document.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging - -import pytest - -from haystack.preview.components.converters.markdown import MarkdownToDocument -from haystack.preview.dataclasses import ByteStream - - -class TestMarkdownToDocument: - @pytest.mark.unit - def test_init_params_default(self): - converter = MarkdownToDocument() - assert converter.table_to_single_line is False - assert converter.progress_bar is True - - @pytest.mark.unit - def test_init_params_custom(self): - converter = MarkdownToDocument(table_to_single_line=True, progress_bar=False) - assert converter.table_to_single_line is True - assert converter.progress_bar is False - - @pytest.mark.integration - def test_run(self, preview_samples_path): - converter = MarkdownToDocument() - sources = [preview_samples_path / "markdown" / "sample.md"] - results = converter.run(sources=sources) - docs = results["documents"] - - assert len(docs) == 1 - for doc in docs: - assert "What to build with Haystack" in doc.content - assert "# git clone 
https://github.com/deepset-ai/haystack.git" in doc.content - - @pytest.mark.integration - def test_run_metadata(self, preview_samples_path): - converter = MarkdownToDocument() - sources = [preview_samples_path / "markdown" / "sample.md"] - metadata = [{"file_name": "sample.md"}] - results = converter.run(sources=sources, meta=metadata) - docs = results["documents"] - - assert len(docs) == 1 - for doc in docs: - assert "What to build with Haystack" in doc.content - assert "# git clone https://github.com/deepset-ai/haystack.git" in doc.content - assert doc.meta == {"file_name": "sample.md"} - - @pytest.mark.integration - def test_run_wrong_file_type(self, preview_samples_path, caplog): - """ - Test if the component runs correctly when an input file is not of the expected type. - """ - sources = [preview_samples_path / "audio" / "answer.wav"] - converter = MarkdownToDocument() - with caplog.at_level(logging.WARNING): - output = converter.run(sources=sources) - assert "codec can't decode byte" in caplog.text - - docs = output["documents"] - assert not docs - - @pytest.mark.integration - def test_run_error_handling(self, caplog): - """ - Test if the component correctly handles errors. - """ - sources = ["non_existing_file.md"] - converter = MarkdownToDocument() - with caplog.at_level(logging.WARNING): - result = converter.run(sources=sources) - assert "Could not read non_existing_file.md" in caplog.text - assert not result["documents"] - - @pytest.mark.unit - def test_mixed_sources_run(self, preview_samples_path): - """ - Test if the component runs correctly if the input is a mix of strings, paths and ByteStreams. - """ - sources = [ - preview_samples_path / "markdown" / "sample.md", - str((preview_samples_path / "markdown" / "sample.md").absolute()), - ] - with open(preview_samples_path / "markdown" / "sample.md", "rb") as f: - byte_stream = f.read() - sources.append(ByteStream(byte_stream)) - - converter = MarkdownToDocument() - output = converter.run(sources=sources) - docs = output["documents"] - assert len(docs) == 3 - for doc in docs: - assert "What to build with Haystack" in doc.content - assert "# git clone https://github.com/deepset-ai/haystack.git" in doc.content diff --git a/test/preview/components/converters/test_pypdf_to_document.py b/test/preview/components/converters/test_pypdf_to_document.py deleted file mode 100644 index e7fb0202fb..0000000000 --- a/test/preview/components/converters/test_pypdf_to_document.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import pytest -from pypdf import PdfReader - -from haystack.preview import Document -from haystack.preview.components.converters.pypdf import PyPDFToDocument, CONVERTERS_REGISTRY -from haystack.preview.dataclasses import ByteStream - - -class TestPyPDFToDocument: - def test_init(self): - component = PyPDFToDocument() - assert component.converter_name == "default" - assert hasattr(component, "_converter") - - def test_init_fail_nonexisting_converter(self): - with pytest.raises(ValueError): - PyPDFToDocument(converter_name="non_existing_converter") - - @pytest.mark.unit - def test_run(self, preview_samples_path): - """ - Test if the component runs correctly. 
- """ - paths = [preview_samples_path / "pdf" / "react_paper.pdf"] - converter = PyPDFToDocument() - output = converter.run(sources=paths) - docs = output["documents"] - assert len(docs) == 1 - assert "ReAct" in docs[0].content - - @pytest.mark.unit - def test_run_error_handling(self, preview_samples_path, caplog): - """ - Test if the component correctly handles errors. - """ - paths = ["non_existing_file.pdf"] - converter = PyPDFToDocument() - with caplog.at_level(logging.WARNING): - converter.run(sources=paths) - assert "Could not read non_existing_file.pdf" in caplog.text - - @pytest.mark.unit - def test_mixed_sources_run(self, preview_samples_path): - """ - Test if the component runs correctly when mixed sources are provided. - """ - paths = [preview_samples_path / "pdf" / "react_paper.pdf"] - with open(preview_samples_path / "pdf" / "react_paper.pdf", "rb") as f: - paths.append(ByteStream(f.read())) - - converter = PyPDFToDocument() - output = converter.run(sources=paths) - docs = output["documents"] - assert len(docs) == 2 - assert "ReAct" in docs[0].content - assert "ReAct" in docs[1].content - - @pytest.mark.unit - def test_custom_converter(self, preview_samples_path): - """ - Test if the component correctly handles custom converters. - """ - paths = [preview_samples_path / "pdf" / "react_paper.pdf"] - - class MyCustomConverter: - def convert(self, reader: PdfReader) -> Document: - return Document(content="I don't care about converting given pdfs, I always return this") - - CONVERTERS_REGISTRY["custom"] = MyCustomConverter() - - converter = PyPDFToDocument(converter_name="custom") - output = converter.run(sources=paths) - docs = output["documents"] - assert len(docs) == 1 - assert "ReAct" not in docs[0].content - assert "I don't care about converting given pdfs, I always return this" in docs[0].content diff --git a/test/preview/components/converters/test_textfile_to_document.py b/test/preview/components/converters/test_textfile_to_document.py deleted file mode 100644 index aafa77d1e2..0000000000 --- a/test/preview/components/converters/test_textfile_to_document.py +++ /dev/null @@ -1,69 +0,0 @@ -import logging -from unittest.mock import patch -from pathlib import Path - -import pytest - -from haystack.preview.dataclasses import ByteStream -from haystack.preview.components.converters.txt import TextFileToDocument - - -class TestTextfileToDocument: - @pytest.mark.unit - def test_run(self, preview_samples_path): - """ - Test if the component runs correctly. - """ - bytestream = ByteStream.from_file_path(preview_samples_path / "txt" / "doc_3.txt") - bytestream.metadata["file_path"] = str(preview_samples_path / "txt" / "doc_3.txt") - bytestream.metadata["key"] = "value" - files = [ - str(preview_samples_path / "txt" / "doc_1.txt"), - preview_samples_path / "txt" / "doc_2.txt", - bytestream, - ] - converter = TextFileToDocument() - output = converter.run(sources=files) - docs = output["documents"] - assert len(docs) == 3 - assert "Some text for testing." in docs[0].content - assert "This is a test line." in docs[1].content - assert "That's yet another file!" in docs[2].content - assert docs[0].meta["file_path"] == str(files[0]) - assert docs[1].meta["file_path"] == str(files[1]) - assert docs[2].meta == bytestream.metadata - - @pytest.mark.unit - def test_run_error_handling(self, preview_samples_path, caplog): - """ - Test if the component correctly handles errors. 
- """ - paths = [ - preview_samples_path / "txt" / "doc_1.txt", - "non_existing_file.txt", - preview_samples_path / "txt" / "doc_3.txt", - ] - converter = TextFileToDocument() - with caplog.at_level(logging.WARNING): - output = converter.run(sources=paths) - assert "non_existing_file.txt" in caplog.text - docs = output["documents"] - assert len(docs) == 2 - assert docs[0].meta["file_path"] == str(paths[0]) - assert docs[1].meta["file_path"] == str(paths[2]) - - @pytest.mark.unit - def test_encoding_override(self, preview_samples_path): - """ - Test if the encoding metadata field is used properly - """ - bytestream = ByteStream.from_file_path(preview_samples_path / "txt" / "doc_1.txt") - bytestream.metadata["key"] = "value" - - converter = TextFileToDocument(encoding="utf-16") - output = converter.run(sources=[bytestream]) - assert "Some text for testing." not in output["documents"][0].content - - bytestream.metadata["encoding"] = "utf-8" - output = converter.run(sources=[bytestream]) - assert "Some text for testing." in output["documents"][0].content diff --git a/test/preview/components/converters/test_tika_doc_converter.py b/test/preview/components/converters/test_tika_doc_converter.py deleted file mode 100644 index c346c4bf95..0000000000 --- a/test/preview/components/converters/test_tika_doc_converter.py +++ /dev/null @@ -1,75 +0,0 @@ -from unittest.mock import patch - -import pytest - -from haystack.preview.components.converters.tika import TikaDocumentConverter - - -class TestTikaDocumentConverter: - @pytest.mark.unit - def test_run(self): - component = TikaDocumentConverter() - with patch("haystack.preview.components.converters.tika.tika_parser.from_file") as mock_tika_parser: - mock_tika_parser.return_value = {"content": "Content of mock_file.pdf"} - documents = component.run(paths=["mock_file.pdf"])["documents"] - - assert len(documents) == 1 - assert documents[0].content == "Content of mock_file.pdf" - - @pytest.mark.unit - def test_run_logs_warning_if_content_empty(self, caplog): - component = TikaDocumentConverter() - with patch("haystack.preview.components.converters.tika.tika_parser.from_file") as mock_tika_parser: - mock_tika_parser.return_value = {"content": ""} - with caplog.at_level("WARNING"): - component.run(paths=["mock_file.pdf"]) - assert "Skipping file at 'mock_file.pdf' as Tika was not able to extract any content." in caplog.text - - @pytest.mark.unit - def test_run_logs_error(self, caplog): - component = TikaDocumentConverter() - with patch("haystack.preview.components.converters.tika.tika_parser.from_file") as mock_tika_parser: - mock_tika_parser.side_effect = Exception("Some error") - with caplog.at_level("ERROR"): - component.run(paths=["mock_file.pdf"]) - assert "Could not convert file at 'mock_file.pdf' to Document. Error: Some error" in caplog.text - - @pytest.mark.integration - def test_run_with_txt_files(self, preview_samples_path): - component = TikaDocumentConverter() - output = component.run( - paths=[preview_samples_path / "txt" / "doc_1.txt", preview_samples_path / "txt" / "doc_2.txt"] - ) - documents = output["documents"] - assert len(documents) == 2 - assert "Some text for testing.\nTwo lines in here." 
in documents[0].content - assert "This is a test line.\n123 456 789\n987 654 321" in documents[1].content - - @pytest.mark.integration - def test_run_with_pdf_file(self, preview_samples_path): - component = TikaDocumentConverter() - output = component.run( - paths=[preview_samples_path / "pdf" / "sample_pdf_1.pdf", preview_samples_path / "pdf" / "sample_pdf_2.pdf"] - ) - documents = output["documents"] - assert len(documents) == 2 - assert "A sample PDF file" in documents[0].content - assert "Page 2 of Sample PDF" in documents[0].content - assert "Page 4 of Sample PDF" in documents[0].content - assert "First Page" in documents[1].content - assert ( - "Wiki engines usually allow content to be written using a simplified markup language" - in documents[1].content - ) - assert "This section needs additional citations for verification." in documents[1].content - assert "This would make it easier for other users to find the article." in documents[1].content - - @pytest.mark.integration - def test_run_with_docx_file(self, preview_samples_path): - component = TikaDocumentConverter() - output = component.run(paths=[preview_samples_path / "docx" / "sample_docx.docx"]) - documents = output["documents"] - assert len(documents) == 1 - assert "Sample Docx File" in documents[0].content - assert "Now we are in Page 2" in documents[0].content - assert "Page 3 was empty this is page 4" in documents[0].content diff --git a/test/preview/components/embedders/__init__.py b/test/preview/components/embedders/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/embedders/test_openai_document_embedder.py b/test/preview/components/embedders/test_openai_document_embedder.py deleted file mode 100644 index 954846c480..0000000000 --- a/test/preview/components/embedders/test_openai_document_embedder.py +++ /dev/null @@ -1,288 +0,0 @@ -from unittest.mock import patch -from typing import List, cast - -import pytest -import numpy as np -import openai -from openai.util import convert_to_openai_object -from openai.openai_object import OpenAIObject - -from haystack.preview import Document -from haystack.preview.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder - - -def mock_openai_response(input: List[str], model: str = "text-embedding-ada-002", **kwargs) -> OpenAIObject: - dict_response = { - "object": "list", - "data": [ - {"object": "embedding", "index": i, "embedding": np.random.rand(1536).tolist()} for i in range(len(input)) - ], - "model": model, - "usage": {"prompt_tokens": 4, "total_tokens": 4}, - } - - return cast(OpenAIObject, convert_to_openai_object(dict_response)) - - -class TestOpenAIDocumentEmbedder: - @pytest.mark.unit - def test_init_default(self, monkeypatch): - openai.api_key = None - monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") - embedder = OpenAIDocumentEmbedder() - - assert openai.api_key == "fake-api-key" - - assert embedder.model_name == "text-embedding-ada-002" - assert embedder.organization is None - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.batch_size == 32 - assert embedder.progress_bar is True - assert embedder.metadata_fields_to_embed == [] - assert embedder.embedding_separator == "\n" - - @pytest.mark.unit - def test_init_with_parameters(self): - embedder = OpenAIDocumentEmbedder( - api_key="fake-api-key", - model_name="model", - organization="my-org", - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - metadata_fields_to_embed=["test_field"], - 
embedding_separator=" | ", - ) - assert openai.api_key == "fake-api-key" - assert openai.organization == "my-org" - - assert embedder.organization == "my-org" - assert embedder.model_name == "model" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False - assert embedder.metadata_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " - - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="OpenAIDocumentEmbedder expects an OpenAI API key"): - OpenAIDocumentEmbedder() - - @pytest.mark.unit - def test_to_dict(self): - component = OpenAIDocumentEmbedder(api_key="fake-api-key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder", - "init_parameters": { - "model_name": "text-embedding-ada-002", - "organization": None, - "prefix": "", - "suffix": "", - "batch_size": 32, - "progress_bar": True, - "metadata_fields_to_embed": [], - "embedding_separator": "\n", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = OpenAIDocumentEmbedder( - api_key="fake-api-key", - model_name="model", - organization="my-org", - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - metadata_fields_to_embed=["test_field"], - embedding_separator=" | ", - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder", - "init_parameters": { - "model_name": "model", - "organization": "my-org", - "prefix": "prefix", - "suffix": "suffix", - "batch_size": 64, - "progress_bar": False, - "metadata_fields_to_embed": ["test_field"], - "embedding_separator": " | ", - }, - } - - @pytest.mark.unit - def test_prepare_texts_to_embed_w_metadata(self): - documents = [ - Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) - ] - - embedder = OpenAIDocumentEmbedder( - api_key="fake-api-key", metadata_fields_to_embed=["meta_field"], embedding_separator=" | " - ) - - prepared_texts = embedder._prepare_texts_to_embed(documents) - - # note that newline is replaced by space - assert prepared_texts == [ - "meta_value 0 | document number 0: content", - "meta_value 1 | document number 1: content", - "meta_value 2 | document number 2: content", - "meta_value 3 | document number 3: content", - "meta_value 4 | document number 4: content", - ] - - @pytest.mark.unit - def test_prepare_texts_to_embed_w_suffix(self): - documents = [Document(content=f"document number {i}") for i in range(5)] - - embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix") - - prepared_texts = embedder._prepare_texts_to_embed(documents) - - assert prepared_texts == [ - "my_prefix document number 0 my_suffix", - "my_prefix document number 1 my_suffix", - "my_prefix document number 2 my_suffix", - "my_prefix document number 3 my_suffix", - "my_prefix document number 4 my_suffix", - ] - - @pytest.mark.unit - def test_embed_batch(self): - texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - - with patch( - "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding" - ) as openai_embedding_patch: - openai_embedding_patch.create.side_effect = 
mock_openai_response - embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", model_name="model") - - embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2) - - assert openai_embedding_patch.create.call_count == 3 - - assert isinstance(embeddings, list) - assert len(embeddings) == len(texts) - for embedding in embeddings: - assert isinstance(embedding, list) - assert len(embedding) == 1536 - assert all(isinstance(x, float) for x in embedding) - - # openai.Embedding.create is called 3 times - assert metadata == {"model": "model", "usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}} - - @pytest.mark.unit - def test_run(self): - docs = [ - Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), - ] - - model = "text-similarity-ada-001" - with patch( - "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding" - ) as openai_embedding_patch: - openai_embedding_patch.create.side_effect = mock_openai_response - embedder = OpenAIDocumentEmbedder( - api_key="fake-api-key", - model_name=model, - prefix="prefix ", - suffix=" suffix", - metadata_fields_to_embed=["topic"], - embedding_separator=" | ", - ) - - result = embedder.run(documents=docs) - - openai_embedding_patch.create.assert_called_once_with( - model=model, - input=[ - "prefix Cuisine | I love cheese suffix", - "prefix ML | A transformer is a deep learning architecture suffix", - ], - ) - documents_with_embeddings = result["documents"] - metadata = result["metadata"] - - assert isinstance(documents_with_embeddings, list) - assert len(documents_with_embeddings) == len(docs) - for doc in documents_with_embeddings: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert len(doc.embedding) == 1536 - assert all(isinstance(x, float) for x in doc.embedding) - assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} - - @pytest.mark.unit - def test_run_custom_batch_size(self): - docs = [ - Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), - ] - - model = "text-similarity-ada-001" - with patch( - "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding" - ) as openai_embedding_patch: - openai_embedding_patch.create.side_effect = mock_openai_response - embedder = OpenAIDocumentEmbedder( - api_key="fake-api-key", - model_name=model, - prefix="prefix ", - suffix=" suffix", - metadata_fields_to_embed=["topic"], - embedding_separator=" | ", - batch_size=1, - ) - - result = embedder.run(documents=docs) - - assert openai_embedding_patch.create.call_count == 2 - - documents_with_embeddings = result["documents"] - metadata = result["metadata"] - - assert isinstance(documents_with_embeddings, list) - assert len(documents_with_embeddings) == len(docs) - for doc in documents_with_embeddings: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert len(doc.embedding) == 1536 - assert all(isinstance(x, float) for x in doc.embedding) - - # openai.Embedding.create is called 2 times - assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} - - @pytest.mark.unit - def test_run_wrong_input_format(self): - embedder = OpenAIDocumentEmbedder(api_key="fake-api-key") - - # wrong formats - string_input = "text" - list_integers_input = [1, 2, 3] - - with 
pytest.raises(TypeError, match="OpenAIDocumentEmbedder expects a list of Documents as input"): - embedder.run(documents=string_input) - - with pytest.raises(TypeError, match="OpenAIDocumentEmbedder expects a list of Documents as input"): - embedder.run(documents=list_integers_input) - - @pytest.mark.unit - def test_run_on_empty_list(self): - embedder = OpenAIDocumentEmbedder(api_key="fake-api-key") - - empty_list_input = [] - result = embedder.run(documents=empty_list_input) - - assert result["documents"] is not None - assert not result["documents"] # empty list diff --git a/test/preview/components/embedders/test_openai_text_embedder.py b/test/preview/components/embedders/test_openai_text_embedder.py deleted file mode 100644 index 50be49ac5d..0000000000 --- a/test/preview/components/embedders/test_openai_text_embedder.py +++ /dev/null @@ -1,118 +0,0 @@ -from unittest.mock import patch -import pytest -import openai -from openai.util import convert_to_openai_object -import numpy as np - -from haystack.preview.components.embedders.openai_text_embedder import OpenAITextEmbedder - - -def mock_openai_response(model: str = "text-embedding-ada-002", **kwargs) -> openai.openai_object.OpenAIObject: - dict_response = { - "object": "list", - "data": [{"object": "embedding", "index": 0, "embedding": np.random.rand(1536).tolist()}], - "model": model, - "usage": {"prompt_tokens": 4, "total_tokens": 4}, - } - - return convert_to_openai_object(dict_response) - - -class TestOpenAITextEmbedder: - @pytest.mark.unit - def test_init_default(self, monkeypatch): - openai.api_key = None - monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") - embedder = OpenAITextEmbedder() - - assert openai.api_key == "fake-api-key" - assert embedder.model_name == "text-embedding-ada-002" - assert embedder.organization is None - assert embedder.prefix == "" - assert embedder.suffix == "" - - @pytest.mark.unit - def test_init_with_parameters(self): - embedder = OpenAITextEmbedder( - api_key="fake-api-key", - model_name="model", - organization="fake-organization", - prefix="prefix", - suffix="suffix", - ) - assert openai.api_key == "fake-api-key" - assert embedder.model_name == "model" - assert embedder.organization == "fake-organization" - assert openai.organization == "fake-organization" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="OpenAITextEmbedder expects an OpenAI API key"): - OpenAITextEmbedder() - - @pytest.mark.unit - def test_to_dict(self): - component = OpenAITextEmbedder(api_key="fake-api-key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.openai_text_embedder.OpenAITextEmbedder", - "init_parameters": { - "model_name": "text-embedding-ada-002", - "organization": None, - "prefix": "", - "suffix": "", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = OpenAITextEmbedder( - api_key="fake-api-key", - model_name="model", - organization="fake-organization", - prefix="prefix", - suffix="suffix", - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.openai_text_embedder.OpenAITextEmbedder", - "init_parameters": { - "model_name": "model", - "organization": "fake-organization", - "prefix": "prefix", - "suffix": "suffix", - }, - } - - @pytest.mark.unit - def 
test_run(self): - model = "text-similarity-ada-001" - - with patch( - "haystack.preview.components.embedders.openai_text_embedder.openai.Embedding" - ) as openai_embedding_patch: - openai_embedding_patch.create.side_effect = mock_openai_response - - embedder = OpenAITextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix") - result = embedder.run(text="The food was delicious") - - openai_embedding_patch.create.assert_called_once_with( - model=model, input="prefix The food was delicious suffix" - ) - - assert len(result["embedding"]) == 1536 - assert all(isinstance(x, float) for x in result["embedding"]) - assert result["metadata"] == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} - - @pytest.mark.unit - def test_run_wrong_input_format(self): - embedder = OpenAITextEmbedder(api_key="fake-api-key") - - list_integers_input = [1, 2, 3] - - with pytest.raises(TypeError, match="OpenAITextEmbedder expects a string as an input"): - embedder.run(text=list_integers_input) diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py deleted file mode 100644 index 2f5e5e667f..0000000000 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ /dev/null @@ -1,210 +0,0 @@ -from unittest.mock import patch, MagicMock -import pytest -import numpy as np - -from haystack.preview import Document -from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( - SentenceTransformersDocumentEmbedder, -) - - -class TestSentenceTransformersDocumentEmbedder: - @pytest.mark.unit - def test_init_default(self): - embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - assert embedder.model_name_or_path == "model" - assert embedder.device == "cpu" - assert embedder.token is None - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.batch_size == 32 - assert embedder.progress_bar is True - assert embedder.normalize_embeddings is False - assert embedder.metadata_fields_to_embed == [] - assert embedder.embedding_separator == "\n" - - @pytest.mark.unit - def test_init_with_parameters(self): - embedder = SentenceTransformersDocumentEmbedder( - model_name_or_path="model", - device="cuda", - token=True, - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - normalize_embeddings=True, - metadata_fields_to_embed=["test_field"], - embedding_separator=" | ", - ) - assert embedder.model_name_or_path == "model" - assert embedder.device == "cuda" - assert embedder.token is True - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False - assert embedder.normalize_embeddings is True - assert embedder.metadata_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " - - @pytest.mark.unit - def test_to_dict(self): - component = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder", - "init_parameters": { - "model_name_or_path": "model", - "device": "cpu", - "token": None, - "prefix": "", - "suffix": "", - "batch_size": 32, - "progress_bar": True, - "normalize_embeddings": False, - "embedding_separator": "\n", - "metadata_fields_to_embed": 
[], - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = SentenceTransformersDocumentEmbedder( - model_name_or_path="model", - device="cuda", - token="the-token", - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - normalize_embeddings=True, - metadata_fields_to_embed=["meta_field"], - embedding_separator=" - ", - ) - data = component.to_dict() - - assert data == { - "type": "haystack.preview.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder", - "init_parameters": { - "model_name_or_path": "model", - "device": "cuda", - "token": None, # the token is not serialized - "prefix": "prefix", - "suffix": "suffix", - "batch_size": 64, - "progress_bar": False, - "normalize_embeddings": True, - "embedding_separator": " - ", - "metadata_fields_to_embed": ["meta_field"], - }, - } - - @pytest.mark.unit - @patch( - "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" - ) - def test_warmup(self, mocked_factory): - embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - mocked_factory.get_embedding_backend.assert_not_called() - embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with( - model_name_or_path="model", device="cpu", use_auth_token=None - ) - - @pytest.mark.unit - @patch( - "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" - ) - def test_warmup_doesnt_reload(self, mocked_factory): - embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - mocked_factory.get_embedding_backend.assert_not_called() - embedder.warm_up() - embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once() - - @pytest.mark.unit - def test_run(self): - embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() - - documents = [Document(content=f"document number {i}") for i in range(5)] - - result = embedder.run(documents=documents) - - assert isinstance(result["documents"], list) - assert len(result["documents"]) == len(documents) - for doc in result["documents"]: - assert isinstance(doc, Document) - assert isinstance(doc.embedding, list) - assert isinstance(doc.embedding[0], float) - - @pytest.mark.unit - def test_run_wrong_input_format(self): - embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model") - - string_input = "text" - list_integers_input = [1, 2, 3] - - with pytest.raises( - TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" - ): - embedder.run(documents=string_input) - - with pytest.raises( - TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input" - ): - embedder.run(documents=list_integers_input) - - @pytest.mark.unit - def test_embed_metadata(self): - embedder = SentenceTransformersDocumentEmbedder( - model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n" - ) - embedder.embedding_backend = MagicMock() - - documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] - - embedder.run(documents=documents) - - embedder.embedding_backend.embed.assert_called_once_with( - [ - "meta_value 0\ndocument number 0", - "meta_value 
1\ndocument number 1", - "meta_value 2\ndocument number 2", - "meta_value 3\ndocument number 3", - "meta_value 4\ndocument number 4", - ], - batch_size=32, - show_progress_bar=True, - normalize_embeddings=False, - ) - - @pytest.mark.unit - def test_prefix_suffix(self): - embedder = SentenceTransformersDocumentEmbedder( - model_name_or_path="model", - prefix="my_prefix ", - suffix=" my_suffix", - metadata_fields_to_embed=["meta_field"], - embedding_separator="\n", - ) - embedder.embedding_backend = MagicMock() - - documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] - - embedder.run(documents=documents) - - embedder.embedding_backend.embed.assert_called_once_with( - [ - "my_prefix meta_value 0\ndocument number 0 my_suffix", - "my_prefix meta_value 1\ndocument number 1 my_suffix", - "my_prefix meta_value 2\ndocument number 2 my_suffix", - "my_prefix meta_value 3\ndocument number 3 my_suffix", - "my_prefix meta_value 4\ndocument number 4 my_suffix", - ], - batch_size=32, - show_progress_bar=True, - normalize_embeddings=False, - ) diff --git a/test/preview/components/embedders/test_sentence_transformers_embedding_backend.py b/test/preview/components/embedders/test_sentence_transformers_embedding_backend.py deleted file mode 100644 index 4ac8c55869..0000000000 --- a/test/preview/components/embedders/test_sentence_transformers_embedding_backend.py +++ /dev/null @@ -1,42 +0,0 @@ -from unittest.mock import patch -import pytest -from haystack.preview.components.embedders.backends.sentence_transformers_backend import ( - _SentenceTransformersEmbeddingBackendFactory, -) - - -@pytest.mark.unit -@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") -def test_factory_behavior(mock_sentence_transformer): - embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path="my_model", device="cpu" - ) - same_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend("my_model", "cpu") - another_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path="another_model", device="cpu" - ) - - assert same_embedding_backend is embedding_backend - assert another_embedding_backend is not embedding_backend - - -@pytest.mark.unit -@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") -def test_model_initialization(mock_sentence_transformer): - _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( - model_name_or_path="model", device="cpu", use_auth_token="my_token" - ) - mock_sentence_transformer.assert_called_once_with( - model_name_or_path="model", device="cpu", use_auth_token="my_token" - ) - - -@pytest.mark.unit -@patch("haystack.preview.components.embedders.backends.sentence_transformers_backend.SentenceTransformer") -def test_embedding_function_with_kwargs(mock_sentence_transformer): - embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model") - - data = ["sentence1", "sentence2"] - embedding_backend.embed(data=data, normalize_embeddings=True) - - embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True) diff --git a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py b/test/preview/components/embedders/test_sentence_transformers_text_embedder.py deleted file mode 100644 index d93e576ac8..0000000000 --- 
a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py +++ /dev/null @@ -1,151 +0,0 @@ -from unittest.mock import patch, MagicMock -import pytest - -import numpy as np - -from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder - - -class TestSentenceTransformersTextEmbedder: - @pytest.mark.unit - def test_init_default(self): - embedder = SentenceTransformersTextEmbedder(model_name_or_path="model") - assert embedder.model_name_or_path == "model" - assert embedder.device == "cpu" - assert embedder.token is None - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.batch_size == 32 - assert embedder.progress_bar is True - assert embedder.normalize_embeddings is False - - @pytest.mark.unit - def test_init_with_parameters(self): - embedder = SentenceTransformersTextEmbedder( - model_name_or_path="model", - device="cuda", - token=True, - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - normalize_embeddings=True, - ) - assert embedder.model_name_or_path == "model" - assert embedder.device == "cuda" - assert embedder.token is True - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False - assert embedder.normalize_embeddings is True - - @pytest.mark.unit - def test_to_dict(self): - component = SentenceTransformersTextEmbedder(model_name_or_path="model") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", - "init_parameters": { - "model_name_or_path": "model", - "device": "cpu", - "token": None, - "prefix": "", - "suffix": "", - "batch_size": 32, - "progress_bar": True, - "normalize_embeddings": False, - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = SentenceTransformersTextEmbedder( - model_name_or_path="model", - device="cuda", - token=True, - prefix="prefix", - suffix="suffix", - batch_size=64, - progress_bar=False, - normalize_embeddings=True, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", - "init_parameters": { - "model_name_or_path": "model", - "device": "cuda", - "token": True, - "prefix": "prefix", - "suffix": "suffix", - "batch_size": 64, - "progress_bar": False, - "normalize_embeddings": True, - }, - } - - @pytest.mark.unit - def test_to_dict_not_serialize_token(self): - component = SentenceTransformersTextEmbedder(model_name_or_path="model", token="awesome-token") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", - "init_parameters": { - "model_name_or_path": "model", - "device": "cpu", - "token": None, - "prefix": "", - "suffix": "", - "batch_size": 32, - "progress_bar": True, - "normalize_embeddings": False, - }, - } - - @pytest.mark.unit - @patch( - "haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory" - ) - def test_warmup(self, mocked_factory): - embedder = SentenceTransformersTextEmbedder(model_name_or_path="model") - mocked_factory.get_embedding_backend.assert_not_called() - embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once_with( - 
model_name_or_path="model", device="cpu", use_auth_token=None - ) - - @pytest.mark.unit - @patch( - "haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory" - ) - def test_warmup_doesnt_reload(self, mocked_factory): - embedder = SentenceTransformersTextEmbedder(model_name_or_path="model") - mocked_factory.get_embedding_backend.assert_not_called() - embedder.warm_up() - embedder.warm_up() - mocked_factory.get_embedding_backend.assert_called_once() - - @pytest.mark.unit - def test_run(self): - embedder = SentenceTransformersTextEmbedder(model_name_or_path="model") - embedder.embedding_backend = MagicMock() - embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist() - - text = "a nice text to embed" - - result = embedder.run(text=text) - embedding = result["embedding"] - - assert isinstance(embedding, list) - assert all(isinstance(el, float) for el in embedding) - - @pytest.mark.unit - def test_run_wrong_input_format(self): - embedder = SentenceTransformersTextEmbedder(model_name_or_path="model") - embedder.embedding_backend = MagicMock() - - list_integers_input = [1, 2, 3] - - with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"): - embedder.run(text=list_integers_input) diff --git a/test/preview/components/fetchers/__init__.py b/test/preview/components/fetchers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/fetchers/test_link_content_fetcher.py b/test/preview/components/fetchers/test_link_content_fetcher.py deleted file mode 100644 index 7d7d4f904d..0000000000 --- a/test/preview/components/fetchers/test_link_content_fetcher.py +++ /dev/null @@ -1,196 +0,0 @@ -from unittest.mock import patch, Mock - -import pytest -import requests - -from haystack.preview.components.fetchers.link_content import ( - LinkContentFetcher, - text_content_handler, - binary_content_handler, - DEFAULT_USER_AGENT, -) - -HTML_URL = "https://docs.haystack.deepset.ai/docs" -TEXT_URL = "https://raw.githubusercontent.com/deepset-ai/haystack/main/README.md" -PDF_URL = "https://raw.githubusercontent.com/deepset-ai/haystack/b5987a6d8d0714eb2f3011183ab40093d2e4a41a/e2e/samples/pipelines/sample_pdf_1.pdf" - - -@pytest.fixture -def mock_get_link_text_content(): - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = Mock( - status_code=200, text="Example test response", headers={"Content-Type": "text/plain"} - ) - yield mock_run - - -@pytest.fixture -def mock_get_link_content(test_files_path): - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = Mock( - status_code=200, - content=open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read(), - headers={"Content-Type": "application/pdf"}, - ) - yield mock_run - - -class TestLinkContentFetcher: - @pytest.mark.unit - def test_init(self): - fetcher = LinkContentFetcher() - assert fetcher.raise_on_failure is True - assert fetcher.user_agents == [DEFAULT_USER_AGENT] - assert fetcher.retry_attempts == 2 - assert fetcher.timeout == 3 - assert fetcher.handlers == { - "text/html": text_content_handler, - "text/plain": text_content_handler, - "application/pdf": binary_content_handler, - "application/octet-stream": binary_content_handler, - } - assert hasattr(fetcher, "_get_response") - - 
@pytest.mark.unit - def test_init_with_params(self): - fetcher = LinkContentFetcher(raise_on_failure=False, user_agents=["test"], retry_attempts=1, timeout=2) - assert fetcher.raise_on_failure is False - assert fetcher.user_agents == ["test"] - assert fetcher.retry_attempts == 1 - assert fetcher.timeout == 2 - - @pytest.mark.unit - def test_run_text(self): - correct_response = b"Example test response" - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = Mock( - status_code=200, text="Example test response", headers={"Content-Type": "text/plain"} - ) - fetcher = LinkContentFetcher() - streams = fetcher.run(urls=["https://www.example.com"])["streams"] - first_stream = streams[0] - assert first_stream.data == correct_response - assert first_stream.metadata["content_type"] == "text/plain" - - @pytest.mark.unit - def test_run_html(self): - correct_response = b"

<html>Example test response</html>
" - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = Mock( - status_code=200, text="<html>Example test response</html>
", headers={"Content-Type": "text/html"} - ) - fetcher = LinkContentFetcher() - streams = fetcher.run(urls=["https://www.example.com"])["streams"] - first_stream = streams[0] - assert first_stream.data == correct_response - assert first_stream.metadata["content_type"] == "text/html" - - @pytest.mark.unit - def test_run_binary(self, test_files_path): - file_bytes = open(test_files_path / "pdf" / "sample_pdf_1.pdf", "rb").read() - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = Mock( - status_code=200, content=file_bytes, headers={"Content-Type": "application/pdf"} - ) - fetcher = LinkContentFetcher() - streams = fetcher.run(urls=["https://www.example.com"])["streams"] - first_stream = streams[0] - assert first_stream.data == file_bytes - assert first_stream.metadata["content_type"] == "application/pdf" - - @pytest.mark.unit - def test_run_bad_status_code(self): - empty_byte_stream = b"" - fetcher = LinkContentFetcher(raise_on_failure=False) - mock_response = Mock(status_code=403) - with patch("haystack.preview.components.fetchers.link_content.requests") as mock_run: - mock_run.get.return_value = mock_response - streams = fetcher.run(urls=["https://www.example.com"])["streams"] - - # empty byte stream is returned because raise_on_failure is False - assert len(streams) == 1 - first_stream = streams[0] - assert first_stream.data == empty_byte_stream - assert first_stream.metadata["content_type"] == "text/html" - - @pytest.mark.integration - def test_link_content_fetcher_html(self): - fetcher = LinkContentFetcher() - streams = fetcher.run([HTML_URL])["streams"] - first_stream = streams[0] - assert "Haystack" in first_stream.data.decode("utf-8") - assert first_stream.metadata["content_type"] == "text/html" - assert "url" in first_stream.metadata and first_stream.metadata["url"] == HTML_URL - - @pytest.mark.integration - def test_link_content_fetcher_text(self): - fetcher = LinkContentFetcher() - streams = fetcher.run([TEXT_URL])["streams"] - first_stream = streams[0] - assert "Haystack" in first_stream.data.decode("utf-8") - assert first_stream.metadata["content_type"] == "text/plain" - assert "url" in first_stream.metadata and first_stream.metadata["url"] == TEXT_URL - - @pytest.mark.integration - def test_link_content_fetcher_pdf(self): - fetcher = LinkContentFetcher() - streams = fetcher.run([PDF_URL])["streams"] - assert len(streams) == 1 - first_stream = streams[0] - assert first_stream.metadata["content_type"] in ("application/octet-stream", "application/pdf") - assert "url" in first_stream.metadata and first_stream.metadata["url"] == PDF_URL - - @pytest.mark.integration - def test_link_content_fetcher_multiple_different_content_types(self): - """ - This test is to ensure that the fetcher can handle a list of URLs that contain different content types. 
- """ - fetcher = LinkContentFetcher() - streams = fetcher.run([PDF_URL, HTML_URL])["streams"] - assert len(streams) == 2 - for stream in streams: - assert stream.metadata["content_type"] in ("text/html", "application/pdf", "application/octet-stream") - if stream.metadata["content_type"] == "text/html": - assert "Haystack" in stream.data.decode("utf-8") - elif stream.metadata["content_type"] == "application/pdf": - assert len(stream.data) > 0 - - @pytest.mark.integration - def test_link_content_fetcher_multiple_html_streams(self): - """ - This test is to ensure that the fetcher can handle a list of URLs that contain different content types, - and that we have two html streams. - """ - - fetcher = LinkContentFetcher() - streams = fetcher.run([PDF_URL, HTML_URL, "https://google.com"])["streams"] - assert len(streams) == 3 - for stream in streams: - assert stream.metadata["content_type"] in ("text/html", "application/pdf", "application/octet-stream") - if stream.metadata["content_type"] == "text/html": - assert "Haystack" in stream.data.decode("utf-8") or "Google" in stream.data.decode("utf-8") - elif stream.metadata["content_type"] == "application/pdf": - assert len(stream.data) > 0 - - @pytest.mark.integration - def test_mix_of_good_and_failed_requests(self): - """ - This test is to ensure that the fetcher can handle a list of URLs that contain URLs that fail to be fetched. - In such a case, the fetcher should return the content of the URLs that were successfully fetched and not raise - an exception. - """ - fetcher = LinkContentFetcher() - result = fetcher.run(["https://non_existent_website_dot.com/", "https://www.google.com/"]) - assert len(result["streams"]) == 1 - first_stream = result["streams"][0] - assert first_stream.metadata["content_type"] == "text/html" - - @pytest.mark.integration - def test_bad_request_exception_raised(self): - """ - This test is to ensure that the fetcher raises an exception when a single bad request is made and it is configured to - do so. 
- """ - fetcher = LinkContentFetcher() - with pytest.raises(requests.exceptions.ConnectionError): - fetcher.run(["https://non_existent_website_dot.com/"]) diff --git a/test/preview/components/generators/chat/__init__.py b/test/preview/components/generators/chat/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/generators/chat/conftest.py b/test/preview/components/generators/chat/conftest.py deleted file mode 100644 index 7a6e7a0fba..0000000000 --- a/test/preview/components/generators/chat/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from haystack.preview.dataclasses import ChatMessage - - -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), - ChatMessage.from_user("Tell me about Berlin"), - ] diff --git a/test/preview/components/generators/chat/test_hugging_face_tgi.py b/test/preview/components/generators/chat/test_hugging_face_tgi.py deleted file mode 100644 index 35de294176..0000000000 --- a/test/preview/components/generators/chat/test_hugging_face_tgi.py +++ /dev/null @@ -1,317 +0,0 @@ -from unittest.mock import patch, MagicMock, Mock - -import pytest -from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason -from huggingface_hub.utils import RepositoryNotFoundError - -from haystack.preview.components.generators.chat import HuggingFaceTGIChatGenerator - -from haystack.preview.dataclasses import StreamingChunk, ChatMessage - - -@pytest.fixture -def mock_check_valid_model(): - with patch( - "haystack.preview.components.generators.chat.hugging_face_tgi.check_valid_model", MagicMock(return_value=None) - ) as mock: - yield mock - - -@pytest.fixture -def mock_text_generation(): - with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: - mock_response = Mock() - mock_response.generated_text = "I'm fine, thanks." 
- details = Mock() - details.finish_reason = MagicMock(field1="value") - details.tokens = [1, 2, 3] - mock_response.details = details - mock_text_generation.return_value = mock_response - yield mock_text_generation - - -# used to test serialization of streaming_callback -def streaming_callback_handler(x): - return x - - -class TestHuggingFaceTGIChatGenerator: - @pytest.mark.unit - def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model, mock_auto_tokenizer): - model = "HuggingFaceH4/zephyr-7b-alpha" - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIChatGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - generator.warm_up() - - assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} - assert generator.tokenizer is not None - assert generator.client is not None - assert generator.streaming_callback == streaming_callback - - @pytest.mark.unit - def test_to_dict(self, mock_check_valid_model): - # Initialize the HuggingFaceTGIChatGenerator object with valid parameters - generator = HuggingFaceTGIChatGenerator( - model="NousResearch/Llama-2-7b-chat-hf", - token="token", - generation_kwargs={"n": 5}, - stop_words=["stop", "words"], - streaming_callback=lambda x: x, - ) - - # Call the to_dict method - result = generator.to_dict() - init_params = result["init_parameters"] - - # Assert that the init_params dictionary contains the expected keys and values - assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf" - assert init_params["token"] is None - assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} - - @pytest.mark.unit - def test_from_dict(self, mock_check_valid_model): - generator = HuggingFaceTGIChatGenerator( - model="NousResearch/Llama-2-7b-chat-hf", - generation_kwargs={"n": 5}, - stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, - ) - # Call the to_dict method - result = generator.to_dict() - - generator_2 = HuggingFaceTGIChatGenerator.from_dict(result) - assert generator_2.model == "NousResearch/Llama-2-7b-chat-hf" - assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} - assert generator_2.streaming_callback is streaming_callback_handler - - @pytest.mark.unit - def test_warm_up(self, mock_check_valid_model, mock_auto_tokenizer): - generator = HuggingFaceTGIChatGenerator() - generator.warm_up() - - # Assert that the tokenizer is now initialized - assert generator.tokenizer is not None - - @pytest.mark.unit - def test_warm_up_no_chat_template(self, mock_check_valid_model, mock_auto_tokenizer, caplog): - generator = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") - - # Set chat_template to None for this specific test - mock_auto_tokenizer.chat_template = None - generator.warm_up() - - # warning message should be logged - assert "The model 'meta-llama/Llama-2-13b-chat-hf' doesn't have a default chat_template" in caplog.text - - @pytest.mark.unit - def test_custom_chat_template( - self, chat_messages, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation - ): - custom_chat_template = "Here goes some Jinja template" - - # mocked method to check if we called apply_chat_template with the custom template - mock_auto_tokenizer.apply_chat_template = MagicMock(return_value="some_value") - - generator = 
HuggingFaceTGIChatGenerator(chat_template=custom_chat_template) - generator.warm_up() - - assert generator.chat_template == custom_chat_template - - generator.run(messages=chat_messages) - assert mock_auto_tokenizer.apply_chat_template.call_count == 1 - - # and we indeed called apply_chat_template with the custom template - _, kwargs = mock_auto_tokenizer.apply_chat_template.call_args - assert kwargs["chat_template"] == custom_chat_template - - @pytest.mark.unit - def test_initialize_with_invalid_model_path_or_url(self, mock_check_valid_model): - model = "invalid_model" - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - mock_check_valid_model.side_effect = ValueError("Invalid model path or url") - - with pytest.raises(ValueError): - HuggingFaceTGIChatGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - - @pytest.mark.unit - def test_initialize_with_invalid_url(self, mock_check_valid_model): - with pytest.raises(ValueError): - HuggingFaceTGIChatGenerator(model="NousResearch/Llama-2-7b-chat-hf", url="invalid_url") - - @pytest.mark.unit - def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model): - # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id - mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") - with pytest.raises(RepositoryNotFoundError): - HuggingFaceTGIChatGenerator(model="invalid_model_id", url="https://some_chat_model.com") - - @pytest.mark.unit - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages - ): - model = "meta-llama/Llama-2-13b-chat-hf" - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIChatGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - generator.warm_up() - - response = generator.run(messages=chat_messages) - - # check kwargs passed to text_generation - # note how n because it is not text generation parameter was not passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop"]} - - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages - ): - model = "meta-llama/Llama-2-13b-chat-hf" - token = None - generation_kwargs = {"n": 3} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIChatGenerator( - model=model, - token=token, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - generator.warm_up() - - response = generator.run(chat_messages) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop"]} - - # note how n caused n replies to be generated - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) 
- assert len(response["replies"]) == 3 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_text_with_stop_words( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages - ): - generator = HuggingFaceTGIChatGenerator() - generator.warm_up() - - stop_words = ["stop", "words"] - - # Generate text response with stop words - response = generator.run(chat_messages, generation_kwargs={"stop_words": stop_words}) - - # check kwargs passed to text_generation - # we translate stop_words to stop_sequences - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_text_with_custom_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages - ): - # Create an instance of HuggingFaceRemoteGenerator with no generation parameters - generator = HuggingFaceTGIChatGenerator() - generator.warm_up() - - # but then we pass them in run - generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} - response = generator.run(chat_messages, generation_kwargs=generation_kwargs) - - # again check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8} - - # Assert that the response contains the generated replies and the right response - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert response["replies"][0].content == "I'm fine, thanks." 
- - @pytest.mark.unit - def test_generate_text_with_streaming_callback( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, chat_messages - ): - streaming_call_count = 0 - - # Define the streaming callback function - def streaming_callback_fn(chunk: StreamingChunk): - nonlocal streaming_call_count - streaming_call_count += 1 - assert isinstance(chunk, StreamingChunk) - - # Create an instance of HuggingFaceRemoteGenerator - generator = HuggingFaceTGIChatGenerator(streaming_callback=streaming_callback_fn) - generator.warm_up() - - # Create a fake streamed response - # self needed here, don't remove - def mock_iter(self): - yield TextGenerationStreamResponse( - generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False) - ) - yield TextGenerationStreamResponse( - generated_text=None, - token=Token(id=1, text="Ok bye", logprob=0.0, special=False), - details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5), - ) - - mock_response = Mock(**{"__iter__": mock_iter}) - mock_text_generation.return_value = mock_response - - # Generate text response with streaming callback - response = generator.run(chat_messages) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": [], "stream": True} - - # Assert that the streaming callback was called twice - assert streaming_call_count == 2 - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] diff --git a/test/preview/components/generators/chat/test_openai.py b/test/preview/components/generators/chat/test_openai.py deleted file mode 100644 index 9535bc14d7..0000000000 --- a/test/preview/components/generators/chat/test_openai.py +++ /dev/null @@ -1,330 +0,0 @@ -import os -from unittest.mock import patch, Mock - -import openai -import pytest - -from haystack.preview.components.generators.chat import GPTChatGenerator -from haystack.preview.components.generators.utils import default_streaming_callback -from haystack.preview.dataclasses import ChatMessage, StreamingChunk - - -@pytest.fixture -def mock_chat_completion(): - """ - Mock the OpenAI API completion response and reuse it for tests - """ - with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create: - # mimic the response from the OpenAI API - mock_choice = Mock() - mock_choice.index = 0 - mock_choice.finish_reason = "stop" - - mock_message = Mock() - mock_message.content = "I'm fine, thanks. How are you?" 
- mock_message.role = "user" - - mock_choice.message = mock_message - - mock_response = Mock() - mock_response.model = "gpt-3.5-turbo" - mock_response.usage = Mock() - mock_response.usage.items.return_value = [ - ("prompt_tokens", 57), - ("completion_tokens", 40), - ("total_tokens", 97), - ] - mock_response.choices = [mock_choice] - mock_chat_completion_create.return_value = mock_response - yield mock_chat_completion_create - - -def streaming_chunk(content: str): - """ - Mock chunks of streaming responses from the OpenAI API - """ - # mimic the chunk response from the OpenAI API - mock_choice = Mock() - mock_choice.index = 0 - mock_choice.delta.content = content - mock_choice.finish_reason = "stop" - - mock_response = Mock() - mock_response.choices = [mock_choice] - mock_response.model = "gpt-3.5-turbo" - mock_response.usage = Mock() - mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)] - return mock_response - - -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France"), - ] - - -class TestGPTChatGenerator: - @pytest.mark.unit - def test_init_default(self): - component = GPTChatGenerator(api_key="test-api-key") - assert openai.api_key == "test-api-key" - assert component.model_name == "gpt-3.5-turbo" - assert component.streaming_callback is None - assert component.api_base_url == "https://api.openai.com/v1" - assert openai.api_base == "https://api.openai.com/v1" - assert not component.generation_kwargs - - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="GPTChatGenerator expects an OpenAI API key"): - GPTChatGenerator() - - @pytest.mark.unit - def test_init_with_parameters(self): - component = GPTChatGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - assert openai.api_key == "test-api-key" - assert component.model_name == "gpt-4" - assert component.streaming_callback is default_streaming_callback - assert component.api_base_url == "test-base-url" - assert openai.api_base == "test-base-url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_to_dict_default(self): - component = GPTChatGenerator(api_key="test-api-key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator", - "init_parameters": { - "model_name": "gpt-3.5-turbo", - "streaming_callback": None, - "api_base_url": "https://api.openai.com/v1", - "generation_kwargs": {}, - }, - } - - @pytest.mark.unit - def test_to_dict_with_parameters(self): - component = GPTChatGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator", - "init_parameters": { - "model_name": "gpt-4", - "api_base_url": "test-base-url", - "streaming_callback": 
"haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - - @pytest.mark.unit - def test_to_dict_with_lambda_streaming_callback(self): - component = GPTChatGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator", - "init_parameters": { - "model_name": "gpt-4", - "api_base_url": "test-base-url", - "streaming_callback": "chat.test_openai.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - - @pytest.mark.unit - def test_from_dict(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") - data = { - "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator", - "init_parameters": { - "model_name": "gpt-4", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - component = GPTChatGenerator.from_dict(data) - assert component.model_name == "gpt-4" - assert component.streaming_callback is default_streaming_callback - assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_from_dict_fail_wo_env_var(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - data = { - "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator", - "init_parameters": { - "model_name": "gpt-4", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - with pytest.raises(ValueError, match="GPTChatGenerator expects an OpenAI API key"): - GPTChatGenerator.from_dict(data) - - @pytest.mark.unit - def test_run(self, chat_messages, mock_chat_completion): - component = GPTChatGenerator(api_key="test-api-key") - response = component.run(chat_messages) - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_run_with_params(self, chat_messages, mock_chat_completion): - component = GPTChatGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5}) - response = component.run(chat_messages) - - # check that the component calls the OpenAI API with the correct parameters - _, kwargs = mock_chat_completion.call_args - assert kwargs["max_tokens"] == 10 - assert kwargs["temperature"] == 0.5 - - # check that the component returns the correct response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_run_streaming(self, chat_messages, mock_chat_completion): - streaming_call_count = 0 
- - # Define the streaming callback function and assert that it is called with StreamingChunk objects - def streaming_callback_fn(chunk: StreamingChunk): - nonlocal streaming_call_count - streaming_call_count += 1 - assert isinstance(chunk, StreamingChunk) - - generator = GPTChatGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn) - - # Create a fake streamed response - # self needed here, don't remove - def mock_iter(self): - yield streaming_chunk("Hello") - yield streaming_chunk("How are you?") - - mock_response = Mock(**{"__iter__": mock_iter}) - mock_chat_completion.return_value = mock_response - - response = generator.run(chat_messages) - - # Assert that the streaming callback was called twice - assert streaming_call_count == 2 - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - - @pytest.mark.unit - def test_check_abnormal_completions(self, caplog): - component = GPTChatGenerator(api_key="test-api-key") - messages = [ - ChatMessage.from_assistant( - "", metadata={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} - ) - for i, _ in enumerate(range(4)) - ] - - for m in messages: - component._check_finish_reason(m) - - # check truncation warning - message_template = ( - "The completion for index {index} has been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." - ) - - for index in [1, 3]: - assert caplog.records[index].message == message_template.format(index=index) - - # check content filter warning - message_template = "The completion for index {index} has been truncated due to the content filter." 
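# The messages built above alternate finish_reason between "content_filter" (even indices
# 0 and 2) and "length" (odd indices 1 and 3), which is why the loop below checks indices
# [0, 2] for the content-filter warning while the earlier loop checked [1, 3] for truncation.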
- for index in [0, 2]: - assert caplog.records[index].message == message_template.format(index=index) - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run(self): - chat_messages = [ChatMessage.from_user("What's the capital of France")] - component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1}) - results = component.run(chat_messages) - assert len(results["replies"]) == 1 - message: ChatMessage = results["replies"][0] - assert "Paris" in message.content - assert "gpt-3.5" in message.metadata["model"] - assert message.metadata["finish_reason"] == "stop" - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run_wrong_model(self, chat_messages): - component = GPTChatGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) - with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"): - component.run(chat_messages) - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run_streaming(self): - class Callback: - def __init__(self): - self.responses = "" - self.counter = 0 - - def __call__(self, chunk: StreamingChunk) -> None: - self.counter += 1 - self.responses += chunk.content if chunk.content else "" - - callback = Callback() - component = GPTChatGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback) - results = component.run([ChatMessage.from_user("What's the capital of France?")]) - - assert len(results["replies"]) == 1 - message: ChatMessage = results["replies"][0] - assert "Paris" in message.content - - assert "gpt-3.5" in message.metadata["model"] - assert message.metadata["finish_reason"] == "stop" - - assert callback.counter > 1 - assert "Paris" in callback.responses diff --git a/test/preview/components/generators/conftest.py b/test/preview/components/generators/conftest.py deleted file mode 100644 index 435b36ea07..0000000000 --- a/test/preview/components/generators/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -from unittest.mock import patch, MagicMock - -import pytest - - -@pytest.fixture -def mock_auto_tokenizer(): - """ - In the original mock_auto_tokenizer fixture, we were mocking the transformers.AutoTokenizer.from_pretrained - method directly, but we were not providing a return value for this method. Therefore, when from_pretrained - was called within HuggingFaceTGIChatGenerator, it returned None because that's the default behavior of a - MagicMock object when a return value isn't specified. - - We will update the mock_auto_tokenizer fixture to return a MagicMock object when from_pretrained is called - in another PR. For now, we will use this fixture to mock the AutoTokenizer.from_pretrained method. 
- """ - - with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: - mock_tokenizer = MagicMock() - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer diff --git a/test/preview/components/generators/test_cohere_generators.py b/test/preview/components/generators/test_cohere_generators.py deleted file mode 100644 index cd1f9cb2a4..0000000000 --- a/test/preview/components/generators/test_cohere_generators.py +++ /dev/null @@ -1,172 +0,0 @@ -import os - -import pytest -import cohere - -from haystack.preview.components.generators import CohereGenerator - - -def default_streaming_callback(chunk): - """ - Default callback function for streaming responses from Cohere API. - Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged. - """ - print(chunk.text, flush=True, end="") - - -class TestGPTGenerator: - def test_init_default(self): - component = CohereGenerator(api_key="test-api-key") - assert component.api_key == "test-api-key" - assert component.model_name == "command" - assert component.streaming_callback is None - assert component.api_base_url == cohere.COHERE_API_URL - assert component.model_parameters == {} - - def test_init_with_parameters(self): - callback = lambda x: x - component = CohereGenerator( - api_key="test-api-key", - model_name="command-light", - max_tokens=10, - some_test_param="test-params", - streaming_callback=callback, - api_base_url="test-base-url", - ) - assert component.api_key == "test-api-key" - assert component.model_name == "command-light" - assert component.streaming_callback == callback - assert component.api_base_url == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - - def test_to_dict_default(self): - component = CohereGenerator(api_key="test-api-key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.cohere.CohereGenerator", - "init_parameters": { - "model_name": "command", - "streaming_callback": None, - "api_base_url": cohere.COHERE_API_URL, - }, - } - - def test_to_dict_with_parameters(self): - component = CohereGenerator( - api_key="test-api-key", - model_name="command-light", - max_tokens=10, - some_test_param="test-params", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.cohere.CohereGenerator", - "init_parameters": { - "model_name": "command-light", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "test_cohere_generators.default_streaming_callback", - }, - } - - def test_to_dict_with_lambda_streaming_callback(self): - component = CohereGenerator( - api_key="test-api-key", - model_name="command", - max_tokens=10, - some_test_param="test-params", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.cohere.CohereGenerator", - "init_parameters": { - "model_name": "command", - "streaming_callback": "test_cohere_generators.", - "api_base_url": "test-base-url", - "max_tokens": 10, - "some_test_param": "test-params", - }, - } - - def test_from_dict(self, monkeypatch): - monkeypatch.setenv("COHERE_API_KEY", "test-key") - data = { - "type": "haystack.preview.components.generators.cohere.CohereGenerator", - "init_parameters": { 
- "model_name": "command", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "test_cohere_generators.default_streaming_callback", - }, - } - component = CohereGenerator.from_dict(data) - assert component.api_key == "test-key" - assert component.model_name == "command" - assert component.streaming_callback == default_streaming_callback - assert component.api_base_url == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - - def test_check_truncated_answers(self, caplog): - component = CohereGenerator(api_key="test-api-key") - metadata = [{"finish_reason": "MAX_TOKENS"}] - component._check_truncated_answers(metadata) - assert caplog.records[0].message == ( - "Responses have been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." - ) - - @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.", - ) - @pytest.mark.integration - def test_cohere_generator_run(self): - component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY")) - results = component.run(prompt="What's the capital of France?") - assert len(results["replies"]) == 1 - assert "Paris" in results["replies"][0] - assert len(results["metadata"]) == 1 - assert results["metadata"][0]["finish_reason"] == "COMPLETE" - - @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", - ) - @pytest.mark.integration - def test_cohere_generator_run_wrong_model_name(self): - component = CohereGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises( - cohere.CohereAPIError, - match="model not found, make sure the correct model ID was used and that you have access to the model.", - ): - component.run(prompt="What's the capital of France?") - - @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", - ) - @pytest.mark.integration - def test_cohere_generator_run_streaming(self): - class Callback: - def __init__(self): - self.responses = "" - - def __call__(self, chunk): - self.responses += chunk.text - return chunk - - callback = Callback() - component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) - results = component.run(prompt="What's the capital of France?") - - assert len(results["replies"]) == 1 - assert "Paris" in results["replies"][0] - assert len(results["metadata"]) == 1 - assert results["metadata"][0]["finish_reason"] == "COMPLETE" - assert callback.responses == results["replies"][0] diff --git a/test/preview/components/generators/test_hf_utils.py b/test/preview/components/generators/test_hf_utils.py deleted file mode 100644 index f69a099743..0000000000 --- a/test/preview/components/generators/test_hf_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from haystack.preview.components.generators.hf_utils import check_generation_params - - -@pytest.mark.unit -def test_empty_dictionary(): - # no exception raised - check_generation_params({}) - - -@pytest.mark.unit -def test_valid_generation_parameters(): - # these are valid parameters - kwargs = {"max_new_tokens": 100, "temperature": 0.8} - additional_accepted_params = None - 
check_generation_params(kwargs, additional_accepted_params) - - -@pytest.mark.unit -def test_invalid_generation_parameters(): - # these are invalid parameters - kwargs = {"invalid_param": "value"} - additional_accepted_params = None - with pytest.raises(ValueError): - check_generation_params(kwargs, additional_accepted_params) - - -@pytest.mark.unit -def test_additional_accepted_params_empty_list(): - kwargs = {"temperature": 0.8} - additional_accepted_params = [] - check_generation_params(kwargs, additional_accepted_params) - - -@pytest.mark.unit -def test_additional_accepted_params_known_parameter(): - # both are valid parameters - kwargs = {"temperature": 0.8} - additional_accepted_params = ["max_new_tokens"] - check_generation_params(kwargs, additional_accepted_params) - - -@pytest.mark.unit -def test_additional_accepted_params_unknown_parameter(): - kwargs = {"strange_param": "value"} - additional_accepted_params = ["strange_param"] - # Although strange_param is not generation param the check_generation_params - # does not raise exception because strange_param is passed as additional_accepted_params - check_generation_params(kwargs, additional_accepted_params) diff --git a/test/preview/components/generators/test_hugging_face_local_generator.py b/test/preview/components/generators/test_hugging_face_local_generator.py deleted file mode 100644 index 367a54c640..0000000000 --- a/test/preview/components/generators/test_hugging_face_local_generator.py +++ /dev/null @@ -1,349 +0,0 @@ -# pylint: disable=too-many-public-methods -from unittest.mock import patch, Mock - -import pytest -import torch - -from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria - - -class TestHuggingFaceLocalGenerator: - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.model_info") - def test_init_default(self, model_info_mock): - model_info_mock.return_value.pipeline_tag = "text2text-generation" - generator = HuggingFaceLocalGenerator() - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": None, - } - assert generator.generation_kwargs == {} - assert generator.pipeline is None - - @pytest.mark.unit - def test_init_custom_token(self): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" - ) - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": "test-token", - } - - @pytest.mark.unit - def test_init_custom_device(self): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", device="cuda:0" - ) - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": None, - "device": "cuda:0", - } - - @pytest.mark.unit - def test_init_task_parameter(self): - generator = HuggingFaceLocalGenerator(task="text2text-generation") - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": None, - } - - @pytest.mark.unit - def test_init_task_in_huggingface_pipeline_kwargs(self): - generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"}) - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": 
None, - } - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.model_info") - def test_init_task_inferred_from_model_name(self, model_info_mock): - model_info_mock.return_value.pipeline_tag = "text2text-generation" - generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-base") - - assert generator.huggingface_pipeline_kwargs == { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": None, - } - - @pytest.mark.unit - def test_init_invalid_task(self): - with pytest.raises(ValueError, match="is not supported."): - HuggingFaceLocalGenerator(task="text-classification") - - @pytest.mark.unit - def test_init_huggingface_pipeline_kwargs_override_other_parameters(self): - """ - huggingface_pipeline_kwargs represent the main configuration of this component. - If they are provided, they should override other init parameters. - """ - - huggingface_pipeline_kwargs = { - "model": "gpt2", - "task": "text-generation", - "device": "cuda:0", - "token": "another-test-token", - } - - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", - task="text2text-generation", - device="cpu", - token="test-token", - huggingface_pipeline_kwargs=huggingface_pipeline_kwargs, - ) - - assert generator.huggingface_pipeline_kwargs == huggingface_pipeline_kwargs - - @pytest.mark.unit - def test_init_generation_kwargs(self): - generator = HuggingFaceLocalGenerator(task="text2text-generation", generation_kwargs={"max_new_tokens": 100}) - - assert generator.generation_kwargs == {"max_new_tokens": 100} - - @pytest.mark.unit - def test_init_set_return_full_text(self): - """ - if not specified, return_full_text is set to False for text-generation task - (only generated text is returned, excluding prompt) - """ - generator = HuggingFaceLocalGenerator(task="text-generation") - - assert generator.generation_kwargs == {"return_full_text": False} - - @pytest.mark.unit - def test_init_fails_with_both_stopwords_and_stoppingcriteria(self): - with pytest.raises( - ValueError, - match="Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`", - ): - HuggingFaceLocalGenerator( - task="text2text-generation", - stop_words=["coca", "cola"], - generation_kwargs={"stopping_criteria": "fake-stopping-criteria"}, - ) - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.model_info") - def test_to_dict_default(self, model_info_mock): - model_info_mock.return_value.pipeline_tag = "text2text-generation" - - component = HuggingFaceLocalGenerator() - data = component.to_dict() - - assert data == { - "type": "haystack.preview.components.generators.hugging_face_local.HuggingFaceLocalGenerator", - "init_parameters": { - "huggingface_pipeline_kwargs": { - "model": "google/flan-t5-base", - "task": "text2text-generation", - "token": None, - }, - "generation_kwargs": {}, - "stop_words": None, - }, - } - - @pytest.mark.unit - def test_to_dict_with_parameters(self): - component = HuggingFaceLocalGenerator( - model_name_or_path="gpt2", - task="text-generation", - device="cuda:0", - token="test-token", - generation_kwargs={"max_new_tokens": 100}, - stop_words=["coca", "cola"], - ) - data = component.to_dict() - - assert data == { - "type": "haystack.preview.components.generators.hugging_face_local.HuggingFaceLocalGenerator", - "init_parameters": { - "huggingface_pipeline_kwargs": { - "model": "gpt2", - "task": "text-generation", - "token": None, # we don't want serialize valid 
tokens - "device": "cuda:0", - }, - "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False}, - "stop_words": ["coca", "cola"], - }, - } - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.pipeline") - def test_warm_up(self, pipeline_mock): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" - ) - pipeline_mock.assert_not_called() - - generator.warm_up() - - pipeline_mock.assert_called_once_with( - model="google/flan-t5-base", task="text2text-generation", token="test-token" - ) - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.pipeline") - def test_warm_up_doesn_reload(self, pipeline_mock): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token" - ) - - pipeline_mock.assert_not_called() - - generator.warm_up() - generator.warm_up() - - pipeline_mock.assert_called_once() - - @pytest.mark.unit - def test_run(self): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", - task="text2text-generation", - generation_kwargs={"max_new_tokens": 100}, - ) - - # create the pipeline object (simulating the warm_up) - generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}]) - - results = generator.run(prompt="What's the capital of Italy?") - - generator.pipeline.assert_called_once_with( - "What's the capital of Italy?", max_new_tokens=100, stopping_criteria=None - ) - assert results == {"replies": ["Rome"]} - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.pipeline") - def test_run_empty_prompt(self, pipeline_mock): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", - task="text2text-generation", - generation_kwargs={"max_new_tokens": 100}, - ) - - generator.warm_up() - - results = generator.run(prompt="") - - assert results == {"replies": []} - - @pytest.mark.unit - def test_run_with_generation_kwargs(self): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", - task="text2text-generation", - generation_kwargs={"max_new_tokens": 100}, - ) - - # create the pipeline object (simulating the warm_up) - generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}]) - - generator.run(prompt="irrelevant", generation_kwargs={"max_new_tokens": 200, "temperature": 0.5}) - - generator.pipeline.assert_called_once_with( - "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None - ) - - @pytest.mark.unit - def test_run_fails_without_warm_up(self): - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", - task="text2text-generation", - generation_kwargs={"max_new_tokens": 100}, - ) - - with pytest.raises(RuntimeError, match="The generation model has not been loaded."): - generator.run(prompt="irrelevant") - - @pytest.mark.unit - def test_stop_words_criteria(self): - """ - Test that StopWordsCriteria will check stop word tokens in a continuous and sequential order - """ - # input ids for "unambiguously" - stop_words_id = torch.tensor([[73, 24621, 11937]]) - - # input ids for "This is ambiguously, but is unrelated." 
- input_ids1 = torch.tensor([[100, 19, 24621, 11937, 6, 68, 19, 73, 3897, 5]]) - # input ids for "This is unambiguously" - input_ids2 = torch.tensor([[100, 19, 73, 24621, 11937]]) - - # We used to implement stop words algorithm using the torch.isin function like this: - # `all(torch.isin(stop_words_id, input_ids1)[0])` - # However, this algorithm is not correct as it will return True for presence of "unambiguously" in input_ids1 - # and True for presence of "unambiguously" in input_ids2. This is because the algorithm will check - # if the stop word tokens are present in the input_ids, but it does not check if the stop word tokens are - # present in a continuous/sequential order. - - # In "This is ambiguously, but is unrelated." sentence the "un" token comes from "unrelated" and the - # "ambiguously" token comes from "ambiguously". The algorithm will return True for presence of - # "unambiguously" in input_ids1 which is not correct. - - stop_words_criteria = StopWordsCriteria(tokenizer=Mock(), stop_words=["mock data"]) - # because we are mocking the tokenizer, we need to set the stop words manually - stop_words_criteria.stop_ids = stop_words_id - - # this is the correct algorithm to check if the stop word tokens are present in a continuous and sequential order - # For the input_ids1, the stop word tokens are present BUT not in a continuous order - present_and_continuous = stop_words_criteria(input_ids1, scores=None) - assert not present_and_continuous - - # For the input_ids2, the stop word tokens are both present and in a continuous order - present_and_continuous = stop_words_criteria(input_ids2, scores=None) - assert present_and_continuous - - @pytest.mark.unit - @patch("haystack.preview.components.generators.hugging_face_local.pipeline") - @patch("haystack.preview.components.generators.hugging_face_local.StopWordsCriteria") - @patch("haystack.preview.components.generators.hugging_face_local.StoppingCriteriaList") - def test_warm_up_set_stopping_criteria_list( - self, pipeline_mock, stop_words_criteria_mock, stopping_criteria_list_mock - ): - """ - Test that warm_up method sets the `stopping_criteria_list` attribute - if `stop_words` is provided - """ - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["coca", "cola"] - ) - - generator.warm_up() - - stop_words_criteria_mock.assert_called_once() - stopping_criteria_list_mock.assert_called_once() - - assert hasattr(generator, "stopping_criteria_list") - - @pytest.mark.unit - def test_run_stop_words_removal(self): - """ - Test that stop words are removed from the generated text - (does not test stopping text generation) - """ - generator = HuggingFaceLocalGenerator( - model_name_or_path="google/flan-t5-base", task="text2text-generation", stop_words=["world"] - ) - - # create the pipeline object (simulating the warm_up) - generator.pipeline = Mock(return_value=[{"generated_text": "Hello world"}]) - - results = generator.run(prompt="irrelevant") - - assert results == {"replies": ["Hello"]} diff --git a/test/preview/components/generators/test_hugging_face_tgi.py b/test/preview/components/generators/test_hugging_face_tgi.py deleted file mode 100644 index 5fcbb304b1..0000000000 --- a/test/preview/components/generators/test_hugging_face_tgi.py +++ /dev/null @@ -1,295 +0,0 @@ -from unittest.mock import patch, MagicMock, Mock - -import pytest -from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason -from 
huggingface_hub.utils import RepositoryNotFoundError - -from haystack.preview.components.generators import HuggingFaceTGIGenerator -from haystack.preview.dataclasses import StreamingChunk - - -@pytest.fixture -def mock_check_valid_model(): - with patch( - "haystack.preview.components.generators.hugging_face_tgi.check_valid_model", MagicMock(return_value=None) - ) as mock: - yield mock - - -@pytest.fixture -def mock_text_generation(): - with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: - mock_response = Mock() - mock_response.generated_text = "I'm fine, thanks." - details = Mock() - details.finish_reason = MagicMock(field1="value") - details.tokens = [1, 2, 3] - mock_response.details = details - mock_text_generation.return_value = mock_response - yield mock_text_generation - - -# used to test serialization of streaming_callback -def streaming_callback_handler(x): - return x - - -class TestHuggingFaceTGIGenerator: - @pytest.mark.unit - def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model): - model = "HuggingFaceH4/zephyr-7b-alpha" - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIGenerator( - model=model, - url=None, - token=None, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - - assert generator.model == model - assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}} - assert generator.tokenizer is None - assert generator.client is not None - assert generator.streaming_callback == streaming_callback - - @pytest.mark.unit - def test_to_dict(self, mock_check_valid_model): - # Initialize the HuggingFaceRemoteGenerator object with valid parameters - generator = HuggingFaceTGIGenerator( - token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x - ) - - # Call the to_dict method - result = generator.to_dict() - init_params = result["init_parameters"] - - # Assert that the init_params dictionary contains the expected keys and values - assert init_params["model"] == "mistralai/Mistral-7B-v0.1" - assert not init_params["token"] - assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]} - - @pytest.mark.unit - def test_from_dict(self, mock_check_valid_model): - generator = HuggingFaceTGIGenerator( - model="mistralai/Mistral-7B-v0.1", - generation_kwargs={"n": 5}, - stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, - ) - # Call the to_dict method - result = generator.to_dict() - - # now deserialize, call from_dict - generator_2 = HuggingFaceTGIGenerator.from_dict(result) - assert generator_2.model == "mistralai/Mistral-7B-v0.1" - assert generator_2.generation_kwargs == {"n": 5, "stop_sequences": ["stop", "words"]} - assert generator_2.streaming_callback is streaming_callback_handler - - @pytest.mark.unit - def test_initialize_with_invalid_url(self, mock_check_valid_model): - with pytest.raises(ValueError): - HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url") - - @pytest.mark.unit - def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model): - # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id - mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") - with pytest.raises(RepositoryNotFoundError): - 
HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com") - - @pytest.mark.unit - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation - ): - model = "mistralai/Mistral-7B-v0.1" - - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - generator.warm_up() - - prompt = "Hello, how are you?" - response = generator.run(prompt) - - # check kwargs passed to text_generation - # note how n was not passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop"]} - - assert isinstance(response, dict) - assert "replies" in response - assert "metadata" in response - assert isinstance(response["replies"], list) - assert isinstance(response["metadata"], list) - assert len(response["replies"]) == 1 - assert len(response["metadata"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation - ): - model = "mistralai/Mistral-7B-v0.1" - generation_kwargs = {"n": 3} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceTGIGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - generator.warm_up() - - prompt = "Hello, how are you?" - response = generator.run(prompt) - - # check kwargs passed to text_generation - # note how n was not passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop"]} - - assert isinstance(response, dict) - assert "replies" in response - assert "metadata" in response - assert isinstance(response["replies"], list) - assert [isinstance(reply, str) for reply in response["replies"]] - - assert isinstance(response["metadata"], list) - assert len(response["replies"]) == 3 - assert len(response["metadata"]) == 3 - assert [isinstance(reply, dict) for reply in response["metadata"]] - - @pytest.mark.unit - def test_initialize_with_invalid_model(self, mock_check_valid_model): - model = "invalid_model" - generation_kwargs = {"n": 1} - stop_words = ["stop"] - streaming_callback = None - - mock_check_valid_model.side_effect = ValueError("Invalid model path or url") - - with pytest.raises(ValueError): - HuggingFaceTGIGenerator( - model=model, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - ) - - @pytest.mark.unit - def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation): - generator = HuggingFaceTGIGenerator() - generator.warm_up() - - # Generate text response with stop words - response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]}) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": ["stop", "words"]} - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - 
assert [isinstance(reply, str) for reply in response["replies"]] - - # Assert that the response contains the metadata - assert "metadata" in response - assert isinstance(response["metadata"], list) - assert len(response["metadata"]) > 0 - assert [isinstance(reply, dict) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_text_with_custom_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation - ): - generator = HuggingFaceTGIGenerator() - generator.warm_up() - - generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} - response = generator.run("How are you?", generation_kwargs=generation_kwargs) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8} - - # Assert that the response contains the generated replies and the right response - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - assert response["replies"][0] == "I'm fine, thanks." - - # Assert that the response contains the metadata - assert "metadata" in response - assert isinstance(response["metadata"], list) - assert len(response["metadata"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - - @pytest.mark.unit - def test_generate_text_with_streaming_callback( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation - ): - streaming_call_count = 0 - - # Define the streaming callback function - def streaming_callback_fn(chunk: StreamingChunk): - nonlocal streaming_call_count - streaming_call_count += 1 - assert isinstance(chunk, StreamingChunk) - - # Create an instance of HuggingFaceRemoteGenerator - generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn) - generator.warm_up() - - # Create a fake streamed response - # Don't remove self - def mock_iter(self): - yield TextGenerationStreamResponse( - generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False) - ) - yield TextGenerationStreamResponse( - generated_text=None, - token=Token(id=1, text="Ok bye", logprob=0.0, special=False), - details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5), - ) - - mock_response = Mock(**{"__iter__": mock_iter}) - mock_text_generation.return_value = mock_response - - # Generate text response with streaming callback - response = generator.run("prompt") - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": [], "stream": True} - - # Assert that the streaming callback was called twice - assert streaming_call_count == 2 - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - - # Assert that the response contains the metadata - assert "metadata" in response - assert isinstance(response["metadata"], list) - assert len(response["metadata"]) > 0 - assert [isinstance(reply, dict) for reply in response["replies"]] diff --git a/test/preview/components/generators/test_openai.py b/test/preview/components/generators/test_openai.py deleted file mode 100644 index 94a862654d..0000000000 --- 
a/test/preview/components/generators/test_openai.py +++ /dev/null @@ -1,343 +0,0 @@ -import os -from typing import List -from unittest.mock import patch, Mock - -import openai -import pytest - -from haystack.preview.components.generators import GPTGenerator -from haystack.preview.components.generators.utils import default_streaming_callback -from haystack.preview.dataclasses import StreamingChunk, ChatMessage - - -@pytest.fixture -def mock_chat_completion(): - """ - Mock the OpenAI API completion response and reuse it for tests - """ - with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create: - # mimic the response from the OpenAI API - mock_choice = Mock() - mock_choice.index = 0 - mock_choice.finish_reason = "stop" - - mock_message = Mock() - mock_message.content = "I'm fine, thanks. How are you?" - mock_message.role = "user" - - mock_choice.message = mock_message - - mock_response = Mock() - mock_response.model = "gpt-3.5-turbo" - mock_response.usage = Mock() - mock_response.usage.items.return_value = [ - ("prompt_tokens", 57), - ("completion_tokens", 40), - ("total_tokens", 97), - ] - mock_response.choices = [mock_choice] - mock_chat_completion_create.return_value = mock_response - yield mock_chat_completion_create - - -def streaming_chunk(content: str): - """ - Mock chunks of streaming responses from the OpenAI API - """ - # mimic the chunk response from the OpenAI API - mock_choice = Mock() - mock_choice.index = 0 - mock_choice.delta.content = content - mock_choice.finish_reason = "stop" - - mock_response = Mock() - mock_response.choices = [mock_choice] - mock_response.model = "gpt-3.5-turbo" - mock_response.usage = Mock() - mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)] - return mock_response - - -class TestGPTGenerator: - @pytest.mark.unit - def test_init_default(self): - component = GPTGenerator(api_key="test-api-key") - assert openai.api_key == "test-api-key" - assert component.model_name == "gpt-3.5-turbo" - assert component.streaming_callback is None - assert component.api_base_url == "https://api.openai.com/v1" - assert openai.api_base == "https://api.openai.com/v1" - assert not component.generation_kwargs - - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): - GPTGenerator() - - @pytest.mark.unit - def test_init_with_parameters(self): - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - assert openai.api_key == "test-api-key" - assert component.model_name == "gpt-4" - assert component.streaming_callback is default_streaming_callback - assert component.api_base_url == "test-base-url" - assert openai.api_base == "test-base-url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_to_dict_default(self): - component = GPTGenerator(api_key="test-api-key") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.openai.GPTGenerator", - "init_parameters": { - "model_name": "gpt-3.5-turbo", - "streaming_callback": None, - "system_prompt": None, - "api_base_url": 
"https://api.openai.com/v1", - "generation_kwargs": {}, - }, - } - - @pytest.mark.unit - def test_to_dict_with_parameters(self): - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.openai.GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": None, - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - - @pytest.mark.unit - def test_to_dict_with_lambda_streaming_callback(self): - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.generators.openai.GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": None, - "api_base_url": "test-base-url", - "streaming_callback": "test_openai.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - - @pytest.mark.unit - def test_from_dict(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") - data = { - "type": "haystack.preview.components.generators.openai.GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": None, - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - component = GPTGenerator.from_dict(data) - assert component.model_name == "gpt-4" - assert component.streaming_callback is default_streaming_callback - assert component.api_base_url == "test-base-url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_from_dict_fail_wo_env_var(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - data = { - "type": "haystack.preview.components.generators.openai.GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): - GPTGenerator.from_dict(data) - - @pytest.mark.unit - def test_run(self, mock_chat_completion): - component = GPTGenerator(api_key="test-api-key") - response = component.run("What's Natural Language Processing?") - - # check that the component returns the correct ChatMessage response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - @pytest.mark.unit - def test_run_with_params(self, mock_chat_completion): - component = GPTGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5}) - response = 
component.run("What's Natural Language Processing?") - - # check that the component calls the OpenAI API with the correct parameters - _, kwargs = mock_chat_completion.call_args - assert kwargs["max_tokens"] == 10 - assert kwargs["temperature"] == 0.5 - - # check that the component returns the correct response - assert isinstance(response, dict) - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] - - @pytest.mark.unit - def test_run_streaming(self, mock_chat_completion): - streaming_call_count = 0 - - # Define the streaming callback function and assert that it is called with StreamingChunk objects - def streaming_callback_fn(chunk: StreamingChunk): - nonlocal streaming_call_count - streaming_call_count += 1 - assert isinstance(chunk, StreamingChunk) - - generator = GPTGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn) - - # Create a fake streamed response - # self needed here, don't remove - def mock_iter(self): - yield streaming_chunk("Hello") - yield streaming_chunk("How are you?") - - mock_response = Mock(**{"__iter__": mock_iter}) - mock_chat_completion.return_value = mock_response - - response = generator.run("Hello there") - - # Assert that the streaming callback was called twice - assert streaming_call_count == 2 - - # Assert that the response contains the generated replies - assert "replies" in response - assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - - @pytest.mark.unit - def test_check_abnormal_completions(self, caplog): - component = GPTGenerator(api_key="test-api-key") - - # underlying implementation uses ChatMessage objects so we have to use them here - messages: List[ChatMessage] = [] - for i, _ in enumerate(range(4)): - message = ChatMessage.from_assistant("Hello") - metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} - message.metadata.update(metadata) - messages.append(message) - - for m in messages: - component._check_finish_reason(m) - - # check truncation warning - message_template = ( - "The completion for index {index} has been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions." - ) - - for index in [1, 3]: - assert caplog.records[index].message == message_template.format(index=index) - - # check content filter warning - message_template = "The completion for index {index} has been truncated due to the content filter." 
- for index in [0, 2]: - assert caplog.records[index].message == message_template.format(index=index) - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run(self): - component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")) - results = component.run("What's the capital of France?") - assert len(results["replies"]) == 1 - assert len(results["metadata"]) == 1 - response: str = results["replies"][0] - assert "Paris" in response - - metadata = results["metadata"][0] - assert "gpt-3.5" in metadata["model"] - assert metadata["finish_reason"] == "stop" - - assert "usage" in metadata - assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0 - assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0 - assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0 - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run_wrong_model(self): - component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) - with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"): - component.run("Whatever") - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_live_run_streaming(self): - class Callback: - def __init__(self): - self.responses = "" - self.counter = 0 - - def __call__(self, chunk: StreamingChunk) -> None: - self.counter += 1 - self.responses += chunk.content if chunk.content else "" - - callback = Callback() - component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback) - results = component.run("What's the capital of France?") - - assert len(results["replies"]) == 1 - assert len(results["metadata"]) == 1 - response: str = results["replies"][0] - assert "Paris" in response - - metadata = results["metadata"][0] - - assert "gpt-3.5" in metadata["model"] - assert metadata["finish_reason"] == "stop" - - # unfortunately, the usage is not available for streaming calls - # we keep the key in the metadata for compatibility - assert "usage" in metadata and len(metadata["usage"]) == 0 - - assert callback.counter > 1 - assert "Paris" in callback.responses diff --git a/test/preview/components/generators/test_utils.py b/test/preview/components/generators/test_utils.py deleted file mode 100644 index 9339502990..0000000000 --- a/test/preview/components/generators/test_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from haystack.preview.components.generators.utils import default_streaming_callback -from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler - - -# streaming callback needs to be on module level -def streaming_callback(chunk): - pass - - -@pytest.mark.unit -def test_callback_handler_serialization(): - result = serialize_callback_handler(streaming_callback) - assert result == "test_utils.streaming_callback" - - -@pytest.mark.unit -def test_callback_handler_serialization_non_local(): - result = 
serialize_callback_handler(default_streaming_callback) - assert result == "haystack.preview.components.generators.utils.default_streaming_callback" - - -@pytest.mark.unit -def test_callback_handler_deserialization(): - result = serialize_callback_handler(streaming_callback) - fn = deserialize_callback_handler(result) - - assert fn is streaming_callback - - -@pytest.mark.unit -def test_callback_handler_deserialization_non_local(): - result = serialize_callback_handler(default_streaming_callback) - fn = deserialize_callback_handler(result) - - assert fn is default_streaming_callback diff --git a/test/preview/components/preprocessors/__init__.py b/test/preview/components/preprocessors/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/preprocessors/test_document_cleaner.py b/test/preview/components/preprocessors/test_document_cleaner.py deleted file mode 100644 index 71f412f3f1..0000000000 --- a/test/preview/components/preprocessors/test_document_cleaner.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging - -import pytest - -from haystack.preview import Document -from haystack.preview.components.preprocessors import DocumentCleaner - - -class TestDocumentCleaner: - @pytest.mark.unit - def test_init(self): - cleaner = DocumentCleaner() - assert cleaner.remove_empty_lines is True - assert cleaner.remove_extra_whitespaces is True - assert cleaner.remove_repeated_substrings is False - assert cleaner.remove_substrings is None - assert cleaner.remove_regex is None - - @pytest.mark.unit - def test_non_text_document(self, caplog): - with caplog.at_level(logging.WARNING): - cleaner = DocumentCleaner() - cleaner.run(documents=[Document()]) - assert "DocumentCleaner only cleans text documents but document.content for document ID" in caplog.text - - @pytest.mark.unit - def test_single_document(self): - with pytest.raises(TypeError, match="DocumentCleaner expects a List of Documents as input."): - cleaner = DocumentCleaner() - cleaner.run(documents=Document()) - - @pytest.mark.unit - def test_empty_list(self): - cleaner = DocumentCleaner() - result = cleaner.run(documents=[]) - assert result == {"documents": []} - - @pytest.mark.unit - def test_remove_empty_lines(self): - cleaner = DocumentCleaner(remove_extra_whitespaces=False) - result = cleaner.run( - documents=[ - Document( - content="This is a text with some words. " - "" - "There is a second sentence. " - "" - "And there is a third sentence." - ) - ] - ) - assert len(result["documents"]) == 1 - assert ( - result["documents"][0].content - == "This is a text with some words. There is a second sentence. And there is a third sentence." - ) - - @pytest.mark.unit - def test_remove_whitespaces(self): - cleaner = DocumentCleaner(remove_empty_lines=False) - result = cleaner.run( - documents=[ - Document( - content=" This is a text with some words. " - "" - "There is a second sentence. " - "" - "And there is a third sentence. " - ) - ] - ) - assert len(result["documents"]) == 1 - assert result["documents"][0].content == ( - "This is a text with some words. " "" "There is a second sentence. " "" "And there is a third sentence." - ) - - @pytest.mark.unit - def test_remove_substrings(self): - cleaner = DocumentCleaner(remove_substrings=["This", "A", "words", "🪲"]) - result = cleaner.run(documents=[Document(content="This is a text with some words.🪲")]) - assert len(result["documents"]) == 1 - assert result["documents"][0].content == " is a text with some ." 
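# A minimal, hypothetical sketch of the substring- and regex-removal behaviour asserted by
# the DocumentCleaner tests here; it is not the actual DocumentCleaner implementation, only
# an illustration: every configured substring is dropped verbatim, and every match of
# remove_regex is deleted from the text.
import re
from typing import Iterable, Optional

def clean_text(text: str, remove_substrings: Optional[Iterable[str]] = None, remove_regex: Optional[str] = None) -> str:
    for substring in remove_substrings or []:
        text = text.replace(substring, "")  # drop the substring wherever it occurs
    if remove_regex:
        text = re.sub(remove_regex, "", text)  # delete every regex match
    return text

# clean_text("This is a text with some words.🪲", remove_substrings=["This", "A", "words", "🪲"])
# returns " is a text with some ." — the value expected in test_remove_substrings above.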
- - @pytest.mark.unit - def test_remove_regex(self): - cleaner = DocumentCleaner(remove_regex=r"\s\s+") - result = cleaner.run(documents=[Document(content="This is a text with some words.")]) - assert len(result["documents"]) == 1 - assert result["documents"][0].content == "This is a text with some words." - - @pytest.mark.unit - def test_remove_repeated_substrings(self): - cleaner = DocumentCleaner( - remove_empty_lines=False, remove_extra_whitespaces=False, remove_repeated_substrings=True - ) - - text = """First Page This is a header. - Page of - 2 - 4 - Lorem ipsum dolor sit amet - This is a footer number 1 - This is footer number 2 This is a header. - Page of - 3 - 4 - Sid ut perspiciatis unde - This is a footer number 1 - This is footer number 2 This is a header. - Page of - 4 - 4 - Sed do eiusmod tempor. - This is a footer number 1 - This is footer number 2""" - - expected_text = """First Page 2 - 4 - Lorem ipsum dolor sit amet 3 - 4 - Sid ut perspiciatis unde 4 - 4 - Sed do eiusmod tempor.""" - result = cleaner.run(documents=[Document(content=text)]) - assert result["documents"][0].content == expected_text - - @pytest.mark.unit - def test_copy_metadata(self): - cleaner = DocumentCleaner() - documents = [ - Document(content="Text. ", meta={"name": "doc 0"}), - Document(content="Text. ", meta={"name": "doc 1"}), - ] - result = cleaner.run(documents=documents) - assert len(result["documents"]) == 2 - assert result["documents"][0].id != result["documents"][1].id - for doc, cleaned_doc in zip(documents, result["documents"]): - assert doc.meta == cleaned_doc.meta - assert cleaned_doc.content == "Text." diff --git a/test/preview/components/preprocessors/test_document_splitter.py b/test/preview/components/preprocessors/test_document_splitter.py deleted file mode 100644 index 4e28d1b135..0000000000 --- a/test/preview/components/preprocessors/test_document_splitter.py +++ /dev/null @@ -1,142 +0,0 @@ -import pytest - -from haystack.preview import Document -from haystack.preview.components.preprocessors import DocumentSplitter - - -class TestDocumentSplitter: - @pytest.mark.unit - def test_non_text_document(self): - with pytest.raises( - ValueError, match="DocumentSplitter only works with text documents but document.content for document ID" - ): - splitter = DocumentSplitter() - splitter.run(documents=[Document()]) - - @pytest.mark.unit - def test_single_doc(self): - with pytest.raises(TypeError, match="DocumentSplitter expects a List of Documents as input."): - splitter = DocumentSplitter() - splitter.run(documents=Document()) - - @pytest.mark.unit - def test_empty_list(self): - splitter = DocumentSplitter() - res = splitter.run(documents=[]) - assert res == {"documents": []} - - @pytest.mark.unit - def test_unsupported_split_by(self): - with pytest.raises(ValueError, match="split_by must be one of 'word', 'sentence' or 'passage'."): - DocumentSplitter(split_by="unsupported") - - @pytest.mark.unit - def test_unsupported_split_length(self): - with pytest.raises(ValueError, match="split_length must be greater than 0."): - DocumentSplitter(split_length=0) - - @pytest.mark.unit - def test_unsupported_split_overlap(self): - with pytest.raises(ValueError, match="split_overlap must be greater than or equal to 0."): - DocumentSplitter(split_overlap=-1) - - @pytest.mark.unit - def test_split_by_word(self): - splitter = DocumentSplitter(split_by="word", split_length=10) - result = splitter.run( - documents=[ - Document( - content="This is a text with some words. There is a second sentence. 
And there is a third sentence." - ) - ] - ) - assert len(result["documents"]) == 2 - assert result["documents"][0].content == "This is a text with some words. There is a " - assert result["documents"][1].content == "second sentence. And there is a third sentence." - - @pytest.mark.unit - def test_split_by_word_multiple_input_docs(self): - splitter = DocumentSplitter(split_by="word", split_length=10) - result = splitter.run( - documents=[ - Document( - content="This is a text with some words. There is a second sentence. And there is a third sentence." - ), - Document( - content="This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence." - ), - ] - ) - assert len(result["documents"]) == 5 - assert result["documents"][0].content == "This is a text with some words. There is a " - assert result["documents"][1].content == "second sentence. And there is a third sentence." - assert result["documents"][2].content == "This is a different text with some words. There is " - assert result["documents"][3].content == "a second sentence. And there is a third sentence. And " - assert result["documents"][4].content == "there is a fourth sentence." - - @pytest.mark.unit - def test_split_by_sentence(self): - splitter = DocumentSplitter(split_by="sentence", split_length=1) - result = splitter.run( - documents=[ - Document( - content="This is a text with some words. There is a second sentence. And there is a third sentence." - ) - ] - ) - assert len(result["documents"]) == 3 - assert result["documents"][0].content == "This is a text with some words." - assert result["documents"][1].content == " There is a second sentence." - assert result["documents"][2].content == " And there is a third sentence." - - @pytest.mark.unit - def test_split_by_passage(self): - splitter = DocumentSplitter(split_by="passage", split_length=1) - result = splitter.run( - documents=[ - Document( - content="This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage." - ) - ] - ) - assert len(result["documents"]) == 3 - assert result["documents"][0].content == "This is a text with some words. There is a second sentence.\n\n" - assert result["documents"][1].content == "And there is a third sentence.\n\n" - assert result["documents"][2].content == " And another passage." - - @pytest.mark.unit - def test_split_by_word_with_overlap(self): - splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2) - result = splitter.run( - documents=[ - Document( - content="This is a text with some words. There is a second sentence. And there is a third sentence." - ) - ] - ) - assert len(result["documents"]) == 2 - assert result["documents"][0].content == "This is a text with some words. There is a " - assert result["documents"][1].content == "is a second sentence. And there is a third sentence." 
- - @pytest.mark.unit - def test_source_id_stored_in_metadata(self): - splitter = DocumentSplitter(split_by="word", split_length=10) - doc1 = Document(content="This is a text with some words.") - doc2 = Document(content="This is a different text with some words.") - result = splitter.run(documents=[doc1, doc2]) - assert result["documents"][0].meta["source_id"] == doc1.id - assert result["documents"][1].meta["source_id"] == doc2.id - - @pytest.mark.unit - def test_copy_metadata(self): - splitter = DocumentSplitter(split_by="word", split_length=10) - documents = [ - Document(content="Text.", meta={"name": "doc 0"}), - Document(content="Text.", meta={"name": "doc 1"}), - ] - result = splitter.run(documents=documents) - assert len(result["documents"]) == 2 - assert result["documents"][0].id != result["documents"][1].id - for doc, split_doc in zip(documents, result["documents"]): - assert doc.meta.items() <= split_doc.meta.items() - assert split_doc.content == "Text." diff --git a/test/preview/components/rankers/test_metafield.py b/test/preview/components/rankers/test_metafield.py deleted file mode 100644 index b6e762c6f8..0000000000 --- a/test/preview/components/rankers/test_metafield.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest - -from haystack.preview import Document, ComponentError -from haystack.preview.components.rankers.meta_field import MetaFieldRanker - - -class TestMetaFieldRanker: - @pytest.mark.unit - def test_to_dict(self): - component = MetaFieldRanker(metadata_field="rating") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.rankers.meta_field.MetaFieldRanker", - "init_parameters": { - "metadata_field": "rating", - "weight": 1.0, - "top_k": None, - "ranking_mode": "reciprocal_rank_fusion", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = MetaFieldRanker(metadata_field="rating", weight=0.5, top_k=5, ranking_mode="linear_score") - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.rankers.meta_field.MetaFieldRanker", - "init_parameters": {"metadata_field": "rating", "weight": 0.5, "top_k": 5, "ranking_mode": "linear_score"}, - } - - @pytest.mark.integration - @pytest.mark.parametrize("metafield_values, expected_first_value", [([1.3, 0.7, 2.1], 2.1), ([1, 5, 8], 8)]) - def test_run(self, metafield_values, expected_first_value): - """ - Test if the component ranks documents correctly. 
- """ - ranker = MetaFieldRanker(metadata_field="rating") - docs_before = [Document(content="abc", meta={"rating": value}) for value in metafield_values] - - output = ranker.run(documents=docs_before) - docs_after = output["documents"] - - assert len(docs_after) == 3 - assert docs_after[0].meta["rating"] == expected_first_value - - sorted_scores = sorted([doc.meta["rating"] for doc in docs_after], reverse=True) - assert [doc.meta["rating"] for doc in docs_after] == sorted_scores - - @pytest.mark.integration - def test_returns_empty_list_if_no_documents_are_provided(self): - ranker = MetaFieldRanker(metadata_field="rating") - output = ranker.run(documents=[]) - docs_after = output["documents"] - assert docs_after == [] - - @pytest.mark.integration - def test_raises_component_error_if_metadata_not_found(self): - ranker = MetaFieldRanker(metadata_field="rating") - docs_before = [Document(content="abc", meta={"wrong_field": 1.3})] - with pytest.raises(ComponentError): - ranker.run(documents=docs_before) - - @pytest.mark.integration - def test_raises_component_error_if_wrong_ranking_mode(self): - with pytest.raises(ValueError): - MetaFieldRanker(metadata_field="rating", ranking_mode="wrong_mode") - - @pytest.mark.integration - @pytest.mark.parametrize("score", [-1, 2, 1.3, 2.1]) - def test_raises_component_error_if_wrong_weight(self, score): - with pytest.raises(ValueError): - MetaFieldRanker(metadata_field="rating", weight=score) - - @pytest.mark.integration - def test_linear_score(self): - ranker = MetaFieldRanker(metadata_field="rating", ranking_mode="linear_score", weight=0.5) - docs_before = [ - Document(content="abc", meta={"rating": 1.3}, score=0.3), - Document(content="abc", meta={"rating": 0.7}, score=0.4), - Document(content="abc", meta={"rating": 2.1}, score=0.6), - ] - output = ranker.run(documents=docs_before) - docs_after = output["documents"] - assert docs_after[0].score == 0.8 - - @pytest.mark.integration - def test_reciprocal_rank_fusion(self): - ranker = MetaFieldRanker(metadata_field="rating", ranking_mode="reciprocal_rank_fusion", weight=0.5) - docs_before = [ - Document(content="abc", meta={"rating": 1.3}, score=0.3), - Document(content="abc", meta={"rating": 0.7}, score=0.4), - Document(content="abc", meta={"rating": 2.1}, score=0.6), - ] - output = ranker.run(documents=docs_before) - docs_after = output["documents"] - assert docs_after[0].score == 0.01626123744050767 - - @pytest.mark.integration - @pytest.mark.parametrize("score", [-1, 2, 1.3, 2.1]) - def test_linear_score_raises_warning_if_doc_wrong_score(self, score): - ranker = MetaFieldRanker(metadata_field="rating", ranking_mode="linear_score", weight=0.5) - docs_before = [ - Document(id=1, content="abc", meta={"rating": 1.3}, score=score), - Document(id=2, content="abc", meta={"rating": 0.7}, score=0.4), - Document(id=3, content="abc", meta={"rating": 2.1}, score=0.6), - ] - with pytest.warns( - UserWarning, match=rf"The score {score} for Document 1 is outside the \[0,1\] range; defaulting to 0" - ): - ranker.run(documents=docs_before) - - @pytest.mark.integration - def test_linear_score_raises_raises_warning_if_doc_without_score(self): - ranker = MetaFieldRanker(metadata_field="rating", ranking_mode="linear_score", weight=0.5) - docs_before = [ - Document(content="abc", meta={"rating": 1.3}), - Document(content="abc", meta={"rating": 0.7}), - Document(content="abc", meta={"rating": 2.1}), - ] - - with pytest.warns(UserWarning, match="The score wasn't provided; defaulting to 0."): - ranker.run(documents=docs_before) 
diff --git a/test/preview/components/rankers/test_transformers_similarity.py b/test/preview/components/rankers/test_transformers_similarity.py deleted file mode 100644 index 95c1d1aea7..0000000000 --- a/test/preview/components/rankers/test_transformers_similarity.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest - -from haystack.preview import Document, ComponentError -from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker - - -class TestSimilarityRanker: - @pytest.mark.unit - def test_to_dict(self): - component = TransformersSimilarityRanker() - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.rankers.transformers_similarity.TransformersSimilarityRanker", - "init_parameters": { - "device": "cpu", - "top_k": 10, - "token": None, - "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - component = TransformersSimilarityRanker( - model_name_or_path="my_model", device="cuda", token="my_token", top_k=5 - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.rankers.transformers_similarity.TransformersSimilarityRanker", - "init_parameters": { - "device": "cuda", - "model_name_or_path": "my_model", - "token": None, # we don't serialize valid tokens, - "top_k": 5, - }, - } - - @pytest.mark.integration - @pytest.mark.parametrize( - "query,docs_before_texts,expected_first_text", - [ - ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), - ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), - ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), - ], - ) - def test_run(self, query, docs_before_texts, expected_first_text): - """ - Test if the component ranks documents correctly. - """ - ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") - ranker.warm_up() - docs_before = [Document(content=text) for text in docs_before_texts] - output = ranker.run(query=query, documents=docs_before) - docs_after = output["documents"] - - assert len(docs_after) == 3 - assert docs_after[0].content == expected_first_text - - sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) - assert [doc.score for doc in docs_after] == sorted_scores - - # Returns an empty list if no documents are provided - @pytest.mark.integration - def test_returns_empty_list_if_no_documents_are_provided(self): - sampler = TransformersSimilarityRanker() - sampler.warm_up() - output = sampler.run(query="City in Germany", documents=[]) - assert not output["documents"] - - # Raises ComponentError if model is not warmed up - @pytest.mark.integration - def test_raises_component_error_if_model_not_warmed_up(self): - sampler = TransformersSimilarityRanker() - - with pytest.raises(ComponentError): - sampler.run(query="query", documents=[Document(content="document")]) - - @pytest.mark.integration - @pytest.mark.parametrize( - "query,docs_before_texts,expected_first_text", - [ - ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), - ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), - ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), - ], - ) - def test_run_top_k(self, query, docs_before_texts, expected_first_text): - """ - Test if the component ranks documents correctly with a custom top_k. 
- """ - ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) - ranker.warm_up() - docs_before = [Document(content=text) for text in docs_before_texts] - output = ranker.run(query=query, documents=docs_before) - docs_after = output["documents"] - - assert len(docs_after) == 2 - assert docs_after[0].content == expected_first_text - - sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) - assert [doc.score for doc in docs_after] == sorted_scores diff --git a/test/preview/components/readers/test_extractive.py b/test/preview/components/readers/test_extractive.py deleted file mode 100644 index 438922ae8d..0000000000 --- a/test/preview/components/readers/test_extractive.py +++ /dev/null @@ -1,415 +0,0 @@ -from math import ceil, exp -from typing import List -from unittest.mock import patch, Mock -import pytest - -import torch -from transformers import pipeline - -from haystack.preview.components.readers import ExtractiveReader -from haystack.preview import Document - - -@pytest.fixture -def mock_tokenizer(): - def mock_tokenize( - texts: List[str], - text_pairs: List[str], - padding: bool, - truncation: bool, - max_length: int, - return_tensors: str, - return_overflowing_tokens: bool, - stride: int, - ): - assert padding - assert truncation - assert return_tensors == "pt" - assert return_overflowing_tokens - - tokens = Mock() - - num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)] - tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)] - num_samples = sum(num_splits) - tokens.encodings = [Mock() for _ in range(num_samples)] - sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32) - for encoding in tokens.encodings: - encoding.sequence_ids = sequence_ids - encoding.token_to_chars = lambda i: (i - 16, i - 15) - tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int) - attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int) - attention_mask[:32] = 1 - tokens.attention_mask = attention_mask - return tokens - - with patch("haystack.preview.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer: - tokenizer.return_value = mock_tokenize - yield tokenizer - - -@pytest.fixture() -def mock_reader(mock_tokenizer): - class MockModel(torch.nn.Module): - def to(self, device): - assert device == "cpu:0" - self.device_set = True - return self - - def forward(self, input_ids, attention_mask, *args, **kwargs): - assert input_ids.device == torch.device("cpu") - assert attention_mask.device == torch.device("cpu") - assert self.device_set - start = torch.zeros(input_ids.shape[:2]) - end = torch.zeros(input_ids.shape[:2]) - start[:, 27] = 1 - end[:, 31] = 1 - end[:, 32] = 1 - prediction = Mock() - prediction.start_logits = start - prediction.end_logits = end - return prediction - - with patch("haystack.preview.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model: - model.return_value = MockModel() - reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0") - reader.warm_up() - return reader - - -example_queries = ["Who is the chancellor of Germany?", "Who is the head of the department?"] -example_documents = [ - [ - Document(content="Angela Merkel was the chancellor of Germany."), - Document(content="Olaf Scholz is the chancellor of Germany"), - Document(content="Jerry is the head of the department."), - ] -] * 2 - - -@pytest.mark.unit -def test_to_dict(): - 
component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"}) - data = component.to_dict() - - assert data == { - "type": "haystack.preview.components.readers.extractive.ExtractiveReader", - "init_parameters": { - "model_name_or_path": "my-model", - "device": None, - "token": None, # don't serialize valid tokens - "top_k": 20, - "confidence_threshold": None, - "max_seq_length": 384, - "stride": 128, - "max_batch_size": None, - "answers_per_seq": None, - "no_answer": True, - "calibration_factor": 0.1, - "model_kwargs": {"torch_dtype": "auto"}, - }, - } - - -@pytest.mark.unit -def test_to_dict_empty_model_kwargs(): - component = ExtractiveReader("my-model", token="secret-token") - data = component.to_dict() - - assert data == { - "type": "haystack.preview.components.readers.extractive.ExtractiveReader", - "init_parameters": { - "model_name_or_path": "my-model", - "device": None, - "token": None, # don't serialize valid tokens - "top_k": 20, - "confidence_threshold": None, - "max_seq_length": 384, - "stride": 128, - "max_batch_size": None, - "answers_per_seq": None, - "no_answer": True, - "calibration_factor": 0.1, - "model_kwargs": {}, - }, - } - - -@pytest.mark.unit -def test_output(mock_reader: ExtractiveReader): - answers = mock_reader.run(example_queries[0], example_documents[0], top_k=3)[ - "answers" - ] # [0] Uncomment and remove first two indices when batching support is reintroduced - doc_ids = set() - no_answer_prob = 1 - for doc, answer in zip(example_documents[0], answers[:3]): - assert answer.start == 11 - assert answer.end == 16 - assert doc.content is not None - assert answer.data == doc.content[11:16] - assert answer.probability == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor))) - no_answer_prob *= 1 - answer.probability - doc_ids.add(doc.id) - assert len(doc_ids) == 3 - assert answers[-1].probability == pytest.approx(no_answer_prob) - - -@pytest.mark.unit -def test_flatten_documents(mock_reader: ExtractiveReader): - queries, docs, query_ids = mock_reader._flatten_documents(example_queries, example_documents) - i = 0 - for j, query in enumerate(example_queries): - for doc in example_documents[j]: - assert queries[i] == query - assert docs[i] == doc - assert query_ids[i] == j - i += 1 - assert len(docs) == len(queries) == len(query_ids) == i - - -@pytest.mark.unit -def test_preprocess(mock_reader: ExtractiveReader): - _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 3, example_documents[0], 384, [1, 1, 1], 0 - ) - expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int) - expected_seq_ids[:, :16] = 0 - expected_seq_ids[:, 16:32] = 1 - assert torch.equal(seq_ids, expected_seq_ids) - assert query_ids == [1, 1, 1] - assert doc_ids == [0, 1, 2] - - -def test_preprocess_splitting(mock_reader: ExtractiveReader): - _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0 - ) - assert seq_ids.shape[0] == 5 - assert query_ids == [1, 1, 1, 1, 1] - assert doc_ids == [0, 1, 2, 3, 3] - - -@pytest.mark.unit -def test_postprocess(mock_reader: ExtractiveReader): - start = torch.zeros((2, 8)) - start[0, 3] = 4 - start[0, 1] = 5 # test attention_mask - start[0, 4] = 3 - start[1, 2] = 1 - - end = torch.zeros((2, 8)) - end[0, 1] = 5 # test attention_mask - end[0, 2] = 4 # test that end can't be before start - end[0, 3] = 3 - end[0, 4] = 2 - end[1, :] = -10 - end[1, 4] = -1 - - sequence_ids = 
torch.ones((2, 8)) - attention_mask = torch.ones((2, 8)) - attention_mask[0, :2] = 0 - encoding = Mock() - encoding.token_to_chars = lambda i: (int(i), int(i) + 1) - - start_candidates, end_candidates, probs = mock_reader._postprocess( - start, end, sequence_ids, attention_mask, 3, [encoding, encoding] - ) - - assert len(start_candidates) == len(end_candidates) == len(probs) == 2 - assert len(start_candidates[0]) == len(end_candidates[0]) == len(probs[0]) == 3 - assert start_candidates[0][0] == 3 - assert end_candidates[0][0] == 4 - assert start_candidates[0][1] == 3 - assert end_candidates[0][1] == 5 - assert start_candidates[0][2] == 4 - assert end_candidates[0][2] == 5 - assert probs[0][0] == pytest.approx(1 / (1 + exp(-7 * mock_reader.calibration_factor))) - assert probs[0][1] == pytest.approx(1 / (1 + exp(-6 * mock_reader.calibration_factor))) - assert probs[0][2] == pytest.approx(1 / (1 + exp(-5 * mock_reader.calibration_factor))) - assert start_candidates[1][0] == 2 - assert end_candidates[1][0] == 5 - assert probs[1][0] == pytest.approx(1 / 2) - - -@pytest.mark.unit -def test_nest_answers(mock_reader: ExtractiveReader): - start = list(range(5)) - end = [i + 5 for i in start] - start = [start] * 6 # type: ignore - end = [end] * 6 # type: ignore - probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25 - query_ids = [0] * 3 + [1] * 3 - document_ids = list(range(3)) * 2 - nested_answers = mock_reader._nest_answers( - start, end, probabilities, example_documents[0], example_queries, 5, 3, None, query_ids, document_ids, True # type: ignore - ) - expected_no_answers = [0.2 * 0.16 * 0.12, 0] - for query, answers, expected_no_answer, probabilities in zip( - example_queries, nested_answers, expected_no_answers, [probabilities[:3, -1], probabilities[3:, -1]] - ): - assert len(answers) == 4 - for doc, answer, probability in zip(example_documents[0], reversed(answers[:3]), probabilities): - assert answer.query == query - assert answer.document == doc - assert answer.probability == pytest.approx(probability) - no_answer = answers[-1] - assert no_answer.query == query - assert no_answer.document is None - assert no_answer.probability == pytest.approx(expected_no_answer) - - -@pytest.mark.unit -@patch("haystack.preview.components.readers.extractive.AutoTokenizer.from_pretrained") -@patch("haystack.preview.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") -def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer): - reader = ExtractiveReader("deepset/roberta-base-squad2", token="fake-token") - reader.warm_up() - - mocked_automodel.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token") - mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token") - - -@pytest.mark.unit -def test_missing_token_to_chars_values(): - # See https://github.com/deepset-ai/haystack/issues/6098 - - def mock_tokenize( - texts: List[str], - text_pairs: List[str], - padding: bool, - truncation: bool, - max_length: int, - return_tensors: str, - return_overflowing_tokens: bool, - stride: int, - ): - assert padding - assert truncation - assert return_tensors == "pt" - assert return_overflowing_tokens - - tokens = Mock() - - num_splits = [ceil(len(text + pair) / max_length) for text, pair in zip(texts, text_pairs)] - tokens.overflow_to_sample_mapping = [i for i, num in enumerate(num_splits) for _ in range(num)] - num_samples = sum(num_splits) - tokens.encodings = [Mock() for _ in 
range(num_samples)] - sequence_ids = [0] * 16 + [1] * 16 + [None] * (max_length - 32) - for encoding in tokens.encodings: - encoding.sequence_ids = sequence_ids - encoding.token_to_chars = lambda i: None - tokens.input_ids = torch.zeros(num_samples, max_length, dtype=torch.int) - attention_mask = torch.zeros(num_samples, max_length, dtype=torch.int) - attention_mask[:32] = 1 - tokens.attention_mask = attention_mask - return tokens - - class MockModel(torch.nn.Module): - def to(self, device): - assert device == "cpu:0" - self.device_set = True - return self - - def forward(self, input_ids, attention_mask, *args, **kwargs): - assert input_ids.device == torch.device("cpu") - assert attention_mask.device == torch.device("cpu") - assert self.device_set - start = torch.zeros(input_ids.shape[:2]) - end = torch.zeros(input_ids.shape[:2]) - start[:, 27] = 1 - end[:, 31] = 1 - end[:, 32] = 1 - prediction = Mock() - prediction.start_logits = start - prediction.end_logits = end - return prediction - - with patch("haystack.preview.components.readers.extractive.AutoTokenizer.from_pretrained") as tokenizer, patch( - "haystack.preview.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained" - ) as model: - tokenizer.return_value = mock_tokenize - model.return_value = MockModel() - reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0") - reader.warm_up() - - answers = reader.run(example_queries[0], example_documents[0], top_k=3)[ - "answers" - ] # [0] Uncomment and remove first two indices when batching support is reintroduced - for doc, answer in zip(example_documents[0], answers[:3]): - assert answer.start is None - assert answer.end is None - assert doc.content is not None - assert answer.data == doc.content - - -@pytest.mark.integration -def test_t5(): - reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad") - reader.warm_up() - answers = reader.run(example_queries[0], example_documents[0], top_k=2)[ - "answers" - ] # remove indices when batching support is reintroduced - assert answers[0].data == "Angela Merkel" - assert answers[0].probability == pytest.approx(0.7764519453048706) - assert answers[1].data == "Olaf Scholz" - assert answers[1].probability == pytest.approx(0.7703777551651001) - assert answers[2].data is None - assert answers[2].probability == pytest.approx(0.051331606147570596) - # Uncomment assertions below when batching is reintroduced - # assert answers[0][2].probability == pytest.approx(0.051331606147570596) - # assert answers[1][0].data == "Jerry" - # assert answers[1][0].probability == pytest.approx(0.7413333654403687) - # assert answers[1][1].data == "Olaf Scholz" - # assert answers[1][1].probability == pytest.approx(0.7266613841056824) - # assert answers[1][2].data is None - # assert answers[1][2].probability == pytest.approx(0.0707035798685709) - - -@pytest.mark.integration -def test_roberta(): - reader = ExtractiveReader("deepset/tinyroberta-squad2") - reader.warm_up() - answers = reader.run(example_queries[0], example_documents[0], top_k=2)[ - "answers" - ] # remove indices when batching is reintroduced - assert answers[0].data == "Olaf Scholz" - assert answers[0].probability == pytest.approx(0.8614975214004517) - assert answers[1].data == "Angela Merkel" - assert answers[1].probability == pytest.approx(0.857952892780304) - assert answers[2].data is None - assert answers[2].probability == pytest.approx(0.019673851661650588) - # uncomment assertions below when there is batching in v2 - # assert answers[0][0].data == "Olaf 
Scholz" - # assert answers[0][0].probability == pytest.approx(0.8614975214004517) - # assert answers[0][1].data == "Angela Merkel" - # assert answers[0][1].probability == pytest.approx(0.857952892780304) - # assert answers[0][2].data is None - # assert answers[0][2].probability == pytest.approx(0.0196738764278237) - # assert answers[1][0].data == "Jerry" - # assert answers[1][0].probability == pytest.approx(0.7048940658569336) - # assert answers[1][1].data == "Olaf Scholz" - # assert answers[1][1].probability == pytest.approx(0.6604189872741699) - # assert answers[1][2].data is None - # assert answers[1][2].probability == pytest.approx(0.1002123719777046) - - -@pytest.mark.integration -def test_matches_hf_pipeline(): - reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu") - reader.warm_up() - answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[ - "answers" - ] # [0] Remove first two indices when batching support is reintroduced - pipe = pipeline("question-answering", model=reader.model, tokenizer=reader.tokenizer, align_to_words=False) - answers_hf = pipe( - question=example_queries[0], - context=example_documents[0][0].content, - max_answer_len=1_000, - handle_impossible_answer=False, - top_k=20, - ) # We need to disable HF postprocessing features to make the results comparable. This is related to https://github.com/huggingface/transformers/issues/26286 - assert len(answers) == len(answers_hf) == 20 - for answer, answer_hf in zip(answers, answers_hf): - assert answer.start == answer_hf["start"] - assert answer.end == answer_hf["end"] - assert answer.data == answer_hf["answer"] diff --git a/test/preview/components/retrievers/__init__.py b/test/preview/components/retrievers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/retrievers/test_in_memory_bm25_retriever.py b/test/preview/components/retrievers/test_in_memory_bm25_retriever.py deleted file mode 100644 index 2a84f3bace..0000000000 --- a/test/preview/components/retrievers/test_in_memory_bm25_retriever.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Dict, Any - -import pytest - -from haystack.preview import Pipeline, DeserializationError -from haystack.preview.testing.factory import document_store_class -from haystack.preview.components.retrievers.in_memory_bm25_retriever import InMemoryBM25Retriever -from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import InMemoryDocumentStore - - -@pytest.fixture() -def mock_docs(): - return [ - Document(content="Javascript is a popular programming language"), - Document(content="Java is a popular programming language"), - Document(content="Python is a popular programming language"), - Document(content="Ruby is a popular programming language"), - Document(content="PHP is a popular programming language"), - ] - - -class TestMemoryBM25Retriever: - @pytest.mark.unit - def test_init_default(self): - retriever = InMemoryBM25Retriever(InMemoryDocumentStore()) - assert retriever.filters is None - assert retriever.top_k == 10 - assert retriever.scale_score is False - - @pytest.mark.unit - def test_init_with_parameters(self): - retriever = InMemoryBM25Retriever( - InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True - ) - assert retriever.filters == {"name": "test.txt"} - assert retriever.top_k == 5 - assert retriever.scale_score - - @pytest.mark.unit - def 
test_init_with_invalid_top_k_parameter(self): - with pytest.raises(ValueError): - InMemoryBM25Retriever(InMemoryDocumentStore(), top_k=-2) - - @pytest.mark.unit - def test_to_dict(self): - MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - document_store = MyFakeStore() - document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} - component = InMemoryBM25Retriever(document_store=document_store) - - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever", - "init_parameters": { - "document_store": {"type": "MyFakeStore", "init_parameters": {}}, - "filters": None, - "top_k": 10, - "scale_score": False, - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - document_store = MyFakeStore() - document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}} - component = InMemoryBM25Retriever( - document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=True - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever", - "init_parameters": { - "document_store": {"type": "MyFakeStore", "init_parameters": {}}, - "filters": {"name": "test.txt"}, - "top_k": 5, - "scale_score": True, - }, - } - - @pytest.mark.unit - def test_from_dict(self): - document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - data = { - "type": "haystack.preview.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever", - "init_parameters": { - "document_store": {"type": "haystack.preview.testing.factory.MyFakeStore", "init_parameters": {}}, - "filters": {"name": "test.txt"}, - "top_k": 5, - }, - } - component = InMemoryBM25Retriever.from_dict(data) - assert isinstance(component.document_store, InMemoryDocumentStore) - assert component.filters == {"name": "test.txt"} - assert component.top_k == 5 - assert component.scale_score is False - - @pytest.mark.unit - def test_from_dict_without_docstore(self): - data = {"type": "InMemoryBM25Retriever", "init_parameters": {}} - with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): - InMemoryBM25Retriever.from_dict(data) - - @pytest.mark.unit - def test_from_dict_without_docstore_type(self): - data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}} - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): - InMemoryBM25Retriever.from_dict(data) - - @pytest.mark.unit - def test_from_dict_nonexisting_docstore(self): - data = { - "type": "haystack.preview.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever", - "init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"): - InMemoryBM25Retriever.from_dict(data) - - @pytest.mark.unit - def test_retriever_valid_run(self, mock_docs): - top_k = 5 - ds = InMemoryDocumentStore() - ds.write_documents(mock_docs) - - retriever = InMemoryBM25Retriever(ds, top_k=top_k) - result = retriever.run(query="PHP") - - assert "documents" in result - assert len(result["documents"]) == top_k - assert result["documents"][0].content == "PHP is a popular programming 
language" - - @pytest.mark.unit - def test_invalid_run_wrong_store_type(self): - SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore") - with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"): - InMemoryBM25Retriever(SomeOtherDocumentStore()) - - @pytest.mark.integration - @pytest.mark.parametrize( - "query, query_result", - [ - ("Javascript", "Javascript is a popular programming language"), - ("Java", "Java is a popular programming language"), - ], - ) - def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): - ds = InMemoryDocumentStore() - ds.write_documents(mock_docs) - retriever = InMemoryBM25Retriever(ds) - - pipeline = Pipeline() - pipeline.add_component("retriever", retriever) - result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query}}) - - assert result - assert "retriever" in result - results_docs = result["retriever"]["documents"] - assert results_docs - assert results_docs[0].content == query_result - - @pytest.mark.integration - @pytest.mark.parametrize( - "query, query_result, top_k", - [ - ("Javascript", "Javascript is a popular programming language", 1), - ("Java", "Java is a popular programming language", 2), - ("Ruby", "Ruby is a popular programming language", 3), - ], - ) - def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): - ds = InMemoryDocumentStore() - ds.write_documents(mock_docs) - retriever = InMemoryBM25Retriever(ds) - - pipeline = Pipeline() - pipeline.add_component("retriever", retriever) - result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query, "top_k": top_k}}) - - assert result - assert "retriever" in result - results_docs = result["retriever"]["documents"] - assert results_docs - assert len(results_docs) == top_k - assert results_docs[0].content == query_result diff --git a/test/preview/components/retrievers/test_in_memory_embedding_retriever.py b/test/preview/components/retrievers/test_in_memory_embedding_retriever.py deleted file mode 100644 index 6c03e6d621..0000000000 --- a/test/preview/components/retrievers/test_in_memory_embedding_retriever.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Dict, Any - -import pytest -import numpy as np - -from haystack.preview import Pipeline, DeserializationError -from haystack.preview.testing.factory import document_store_class -from haystack.preview.components.retrievers.in_memory_embedding_retriever import InMemoryEmbeddingRetriever -from haystack.preview.dataclasses import Document -from haystack.preview.document_stores import InMemoryDocumentStore - - -class TestMemoryEmbeddingRetriever: - @pytest.mark.unit - def test_init_default(self): - retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore()) - assert retriever.filters is None - assert retriever.top_k == 10 - assert retriever.scale_score is False - - @pytest.mark.unit - def test_init_with_parameters(self): - retriever = InMemoryEmbeddingRetriever( - InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True - ) - assert retriever.filters == {"name": "test.txt"} - assert retriever.top_k == 5 - assert retriever.scale_score - - @pytest.mark.unit - def test_init_with_invalid_top_k_parameter(self): - with pytest.raises(ValueError): - InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2) - - @pytest.mark.unit - def test_to_dict(self): - MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - document_store = MyFakeStore() - 
document_store.to_dict = lambda: {"type": "test_module.MyFakeStore", "init_parameters": {}} - component = InMemoryEmbeddingRetriever(document_store=document_store) - - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": { - "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}}, - "filters": None, - "top_k": 10, - "scale_score": False, - "return_embedding": False, - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - document_store = MyFakeStore() - document_store.to_dict = lambda: {"type": "test_module.MyFakeStore", "init_parameters": {}} - component = InMemoryEmbeddingRetriever( - document_store=document_store, - filters={"name": "test.txt"}, - top_k=5, - scale_score=True, - return_embedding=True, - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": { - "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}}, - "filters": {"name": "test.txt"}, - "top_k": 5, - "scale_score": True, - "return_embedding": True, - }, - } - - @pytest.mark.unit - def test_from_dict(self): - document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,)) - data = { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": { - "document_store": {"type": "haystack.preview.testing.factory.MyFakeStore", "init_parameters": {}}, - "filters": {"name": "test.txt"}, - "top_k": 5, - }, - } - component = InMemoryEmbeddingRetriever.from_dict(data) - assert isinstance(component.document_store, InMemoryDocumentStore) - assert component.filters == {"name": "test.txt"} - assert component.top_k == 5 - assert component.scale_score is False - - @pytest.mark.unit - def test_from_dict_without_docstore(self): - data = { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": {}, - } - with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): - InMemoryEmbeddingRetriever.from_dict(data) - - @pytest.mark.unit - def test_from_dict_without_docstore_type(self): - data = { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": {"document_store": {"init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): - InMemoryEmbeddingRetriever.from_dict(data) - - @pytest.mark.unit - def test_from_dict_nonexisting_docstore(self): - data = { - "type": "haystack.preview.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever", - "init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"): - InMemoryEmbeddingRetriever.from_dict(data) - - @pytest.mark.unit - def test_valid_run(self): - top_k = 3 - ds = InMemoryDocumentStore(embedding_similarity_function="cosine") - docs = [ - Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), - Document(content="third document", 
embedding=[0.5, 0.7, 0.5, 0.7]), - ] - ds.write_documents(docs) - - retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k) - result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True) - - assert "documents" in result - assert len(result["documents"]) == top_k - assert np.array_equal(result["documents"][0].embedding, [1.0, 1.0, 1.0, 1.0]) - - @pytest.mark.unit - def test_invalid_run_wrong_store_type(self): - SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore") - with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"): - InMemoryEmbeddingRetriever(SomeOtherDocumentStore()) - - @pytest.mark.integration - def test_run_with_pipeline(self): - ds = InMemoryDocumentStore(embedding_similarity_function="cosine") - top_k = 2 - docs = [ - Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]), - Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]), - ] - ds.write_documents(docs) - retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k) - - pipeline = Pipeline() - pipeline.add_component("retriever", retriever) - result: Dict[str, Any] = pipeline.run( - data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}} - ) - - assert result - assert "retriever" in result - results_docs = result["retriever"]["documents"] - assert results_docs - assert len(results_docs) == top_k - assert np.array_equal(results_docs[0].embedding, [1.0, 1.0, 1.0, 1.0]) diff --git a/test/preview/components/routers/__init__.py b/test/preview/components/routers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/routers/test_conditional_router.py b/test/preview/components/routers/test_conditional_router.py deleted file mode 100644 index 501fb530d1..0000000000 --- a/test/preview/components/routers/test_conditional_router.py +++ /dev/null @@ -1,324 +0,0 @@ -import copy -import typing -from typing import List, Dict -from unittest import mock - -import pytest - -from haystack.preview.components.routers import ConditionalRouter -from haystack.preview.components.routers.conditional_router import ( - NoRouteSelectedException, - serialize_type, - deserialize_type, -) -from haystack.preview.dataclasses import ChatMessage - - -class TestRouter: - @pytest.fixture - def routes(self): - return [ - {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, - { - "condition": "{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - - @pytest.fixture - def router(self, routes): - return ConditionalRouter(routes) - - def test_missing_mandatory_fields(self): - """ - Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys - """ - routes = [ - {"condition": "{{streams|length < 2}}", "output": "{{query}}"}, - {"condition": "{{streams|length < 2}}", "output_type": str}, - ] - with pytest.raises(ValueError): - ConditionalRouter(routes) - - def test_invalid_condition_field(self): - """ - ConditionalRouter init raises a ValueError if one of the routes contains invalid condition - """ - # invalid condition field - routes = [{"condition": "{{streams|length < 2", "output": "query", "output_type": str, "output_name": "test"}] - with pytest.raises(ValueError, match="Invalid template"): - ConditionalRouter(routes) - - def 
test_no_vars_in_output_route_but_with_output_name(self): - """ - Router can't accept a route with no variables used in the output field - """ - routes = [ - { - "condition": "{{streams|length > 2}}", - "output": "This is a constant", - "output_name": "enough_streams", - "output_type": str, - } - ] - router = ConditionalRouter(routes) - kwargs = {"streams": [1, 2, 3], "query": "Haystack"} - result = router.run(**kwargs) - assert result == {"enough_streams": "This is a constant"} - - def test_mandatory_and_optional_fields_with_extra_fields(self): - """ - Router accepts a list of routes with mandatory and optional fields but not if some new field is added - """ - - routes = [ - { - "condition": "{{streams|length < 2}}", - "output": "{{query}}", - "output_type": str, - "output_name": "test", - "bla": "bla", - }, - {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str}, - ] - - with pytest.raises(ValueError): - ConditionalRouter(routes) - - def test_router_initialized(self, routes): - router = ConditionalRouter(routes) - - assert router.routes == routes - assert set(router.__canals_input__.keys()) == {"query", "streams"} - assert set(router.__canals_output__.keys()) == {"query", "streams"} - - def test_router_evaluate_condition_expressions(self, router): - # first route should be selected - kwargs = {"streams": [1, 2, 3], "query": "test"} - result = router.run(**kwargs) - assert result == {"streams": [1, 2, 3]} - - # second route should be selected - kwargs = {"streams": [1], "query": "test"} - result = router.run(**kwargs) - assert result == {"query": "test"} - - def test_router_evaluate_condition_expressions_using_output_slot(self): - routes = [ - { - "condition": "{{streams|length > 2}}", - "output": "{{streams}}", - "output_name": "enough_streams", - "output_type": List[int], - }, - { - "condition": "{{streams|length <= 2}}", - "output": "{{streams}}", - "output_name": "insufficient_streams", - "output_type": List[int], - }, - ] - router = ConditionalRouter(routes) - # enough_streams output slot will be selected with [1, 2, 3] list being outputted - kwargs = {"streams": [1, 2, 3], "query": "Haystack"} - result = router.run(**kwargs) - assert result == {"enough_streams": [1, 2, 3]} - - def test_complex_condition(self): - routes = [ - { - "condition": "{{messages[-1].metadata.finish_reason == 'function_call'}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - { - "condition": "{{True}}", - "output": "{{query}}", - "output_type": str, - "output_name": "query", - }, # catch-all condition - ] - router = ConditionalRouter(routes) - message = mock.MagicMock() - message.metadata.finish_reason = "function_call" - result = router.run(messages=[message], streams=[1, 2, 3], query="my query") - assert result == {"streams": [1, 2, 3]} - - def test_router_no_route(self, router): - # should raise an exception - router = ConditionalRouter( - [ - { - "condition": "{{streams|length < 2}}", - "output": "{{query}}", - "output_type": str, - "output_name": "query", - }, - { - "condition": "{{streams|length >= 5}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - ) - - kwargs = {"streams": [1, 2, 3], "query": "test"} - with pytest.raises(NoRouteSelectedException): - router.run(**kwargs) - - def test_router_raises_value_error_if_route_not_dictionary(self): - """ - Router raises a ValueError if each route is not a dictionary - """ - routes = [ - {"condition": "{{streams|length < 2}}", "output": 
"{{query}}", "output_type": str, "output_name": "query"}, - ["{{streams|length >= 2}}", "streams", List[int]], - ] - - with pytest.raises(ValueError): - ConditionalRouter(routes) - - def test_router_raises_value_error_if_route_missing_keys(self): - """ - Router raises a ValueError if each route does not contain 'condition', 'output', and 'output_type' keys - """ - routes = [ - {"condition": "{{streams|length < 2}}", "output": "{{query}}"}, - {"condition": "{{streams|length < 2}}", "output_type": str}, - ] - - with pytest.raises(ValueError): - ConditionalRouter(routes) - - def test_output_type_serialization(self): - assert serialize_type(str) == "str" - assert serialize_type(List[int]) == "typing.List[int]" - assert serialize_type(List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" - assert serialize_type(ChatMessage) == "haystack.preview.dataclasses.chat_message.ChatMessage" - assert serialize_type(typing.List[Dict[str, int]]) == "typing.List[typing.Dict[str, int]]" - assert serialize_type(List[ChatMessage]) == "typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]" - assert ( - serialize_type(typing.Dict[int, ChatMessage]) - == "typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]" - ) - assert serialize_type(int) == "int" - assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.preview.dataclasses.chat_message.ChatMessage" - - def test_output_type_deserialization(self): - assert deserialize_type("str") == str - assert deserialize_type("typing.List[int]") == typing.List[int] - assert deserialize_type("typing.List[typing.Dict[str, int]]") == typing.List[Dict[str, int]] - assert deserialize_type("typing.Dict[str, int]") == Dict[str, int] - assert deserialize_type("typing.Dict[str, typing.List[int]]") == Dict[str, List[int]] - assert deserialize_type("typing.List[typing.Dict[str, typing.List[int]]]") == List[Dict[str, List[int]]] - assert ( - deserialize_type("typing.List[haystack.preview.dataclasses.chat_message.ChatMessage]") - == typing.List[ChatMessage] - ) - assert ( - deserialize_type("typing.Dict[int, haystack.preview.dataclasses.chat_message.ChatMessage]") - == typing.Dict[int, ChatMessage] - ) - assert deserialize_type("haystack.preview.dataclasses.chat_message.ChatMessage") == ChatMessage - assert deserialize_type("int") == int - - def test_router_de_serialization(self): - routes = [ - {"condition": "{{streams|length < 2}}", "output": "{{query}}", "output_type": str, "output_name": "query"}, - { - "condition": "{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - router = ConditionalRouter(routes) - router_dict = router.to_dict() - - # assert that the router dict is correct, with all keys and values being strings - for route in router_dict["init_parameters"]["routes"]: - for key in route.keys(): - assert isinstance(key, str) - assert isinstance(route[key], str) - - new_router = ConditionalRouter.from_dict(router_dict) - assert router.routes == new_router.routes - - # now use both routers with the same input - kwargs = {"streams": [1, 2, 3], "query": "Haystack"} - result1 = router.run(**kwargs) - result2 = new_router.run(**kwargs) - - # check that the result is the same and correct - assert result1 == result2 and result1 == {"streams": [1, 2, 3]} - - def test_router_de_serialization_user_type(self): - routes = [ - { - "condition": "{{streams|length < 2}}", - "output": "{{message}}", - "output_type": ChatMessage, - "output_name": "message", - }, - { - "condition": 
"{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - router = ConditionalRouter(routes) - router_dict = router.to_dict() - - # assert that the router dict is correct, with all keys and values being strings - for route in router_dict["init_parameters"]["routes"]: - for key in route.keys(): - assert isinstance(key, str) - assert isinstance(route[key], str) - - # check that the output_type is a string and a proper class name - assert ( - router_dict["init_parameters"]["routes"][0]["output_type"] - == "haystack.preview.dataclasses.chat_message.ChatMessage" - ) - - # deserialize the router - new_router = ConditionalRouter.from_dict(router_dict) - - # check that the output_type is the right class - assert new_router.routes[0]["output_type"] == ChatMessage - assert router.routes == new_router.routes - - # now use both routers to run the same message - message = ChatMessage.from_user("ciao") - kwargs = {"streams": [1], "message": message} - result1 = router.run(**kwargs) - result2 = new_router.run(**kwargs) - - # check that the result is the same and correct - assert result1 == result2 and result1["message"].content == message.content - - def test_router_serialization_idempotence(self): - routes = [ - { - "condition": "{{streams|length < 2}}", - "output": "{{message}}", - "output_type": ChatMessage, - "output_name": "message", - }, - { - "condition": "{{streams|length >= 2}}", - "output": "{{streams}}", - "output_type": List[int], - "output_name": "streams", - }, - ] - router = ConditionalRouter(routes) - # invoke to_dict twice and check that the result is the same - router_dict_first_invocation = copy.deepcopy(router.to_dict()) - router_dict_second_invocation = router.to_dict() - assert router_dict_first_invocation == router_dict_second_invocation diff --git a/test/preview/components/routers/test_document_joiner.py b/test/preview/components/routers/test_document_joiner.py deleted file mode 100644 index 9b4ab7bf2a..0000000000 --- a/test/preview/components/routers/test_document_joiner.py +++ /dev/null @@ -1,140 +0,0 @@ -import logging - -import pytest - -from haystack.preview import Document -from haystack.preview.components.routers.document_joiner import DocumentJoiner - - -class TestDocumentJoiner: - @pytest.mark.unit - def test_init(self): - joiner = DocumentJoiner() - assert joiner.join_mode == "concatenate" - assert joiner.weights is None - assert joiner.top_k is None - assert joiner.sort_by_score - - @pytest.mark.unit - def test_init_with_custom_parameters(self): - joiner = DocumentJoiner(join_mode="merge", weights=[0.4, 0.6], top_k=5, sort_by_score=False) - assert joiner.join_mode == "merge" - assert joiner.weights == [0.4, 0.6] - assert joiner.top_k == 5 - assert not joiner.sort_by_score - - @pytest.mark.unit - def test_empty_list(self): - joiner = DocumentJoiner() - result = joiner.run([]) - assert result == {"documents": []} - - @pytest.mark.unit - def test_list_of_empty_lists(self): - joiner = DocumentJoiner() - result = joiner.run([[], []]) - assert result == {"documents": []} - - @pytest.mark.unit - def test_list_with_one_empty_list(self): - joiner = DocumentJoiner() - documents = [Document(content="a"), Document(content="b"), Document(content="c")] - result = joiner.run([[], documents]) - assert result == {"documents": documents} - - @pytest.mark.unit - def test_unsupported_join_mode(self): - with pytest.raises(ValueError, match="DocumentJoiner component does not support 'unsupported_mode' join_mode."): - 
DocumentJoiner(join_mode="unsupported_mode") - - @pytest.mark.unit - def test_run_with_concatenate_join_mode_and_top_k(self): - joiner = DocumentJoiner(top_k=6) - documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")] - documents_2 = [ - Document(content="d"), - Document(content="e"), - Document(content="f", meta={"key": "value"}), - Document(content="g"), - ] - output = joiner.run([documents_1, documents_2]) - assert len(output["documents"]) == 6 - assert sorted(documents_1 + documents_2[:-1], key=lambda d: d.id) == sorted( - output["documents"], key=lambda d: d.id - ) - - @pytest.mark.unit - def test_run_with_concatenate_join_mode_and_duplicate_documents(self): - joiner = DocumentJoiner() - documents_1 = [Document(content="a", score=0.3), Document(content="b"), Document(content="c")] - documents_2 = [ - Document(content="a", score=0.2), - Document(content="a"), - Document(content="f", meta={"key": "value"}), - ] - output = joiner.run([documents_1, documents_2]) - assert len(output["documents"]) == 4 - assert sorted(documents_1 + [documents_2[-1]], key=lambda d: d.id) == sorted( - output["documents"], key=lambda d: d.id - ) - - @pytest.mark.unit - def test_run_with_merge_join_mode(self): - joiner = DocumentJoiner(join_mode="merge", weights=[1.5, 0.5]) - documents_1 = [Document(content="a", score=1.0), Document(content="b", score=2.0)] - documents_2 = [ - Document(content="a", score=0.5), - Document(content="b", score=3.0), - Document(content="f", score=4.0, meta={"key": "value"}), - ] - output = joiner.run([documents_1, documents_2]) - assert len(output["documents"]) == 3 - expected_document_ids = [ - doc.id - for doc in [ - Document(content="a", score=1.25), - Document(content="b", score=2.25), - Document(content="f", score=4.0, meta={"key": "value"}), - ] - ] - assert all(doc.id in expected_document_ids for doc in output["documents"]) - - @pytest.mark.unit - def test_run_with_reciprocal_rank_fusion_join_mode(self): - joiner = DocumentJoiner(join_mode="reciprocal_rank_fusion") - documents_1 = [Document(content="a"), Document(content="b"), Document(content="c")] - documents_2 = [ - Document(content="b", score=1000.0), - Document(content="c"), - Document(content="a"), - Document(content="f", meta={"key": "value"}), - ] - output = joiner.run([documents_1, documents_2]) - assert len(output["documents"]) == 4 - expected_document_ids = [ - doc.id - for doc in [ - Document(content="b"), - Document(content="a"), - Document(content="c"), - Document(content="f", meta={"key": "value"}), - ] - ] - assert all(doc.id in expected_document_ids for doc in output["documents"]) - - @pytest.mark.unit - def test_sort_by_score_without_scores(self, caplog): - joiner = DocumentJoiner() - with caplog.at_level(logging.INFO): - documents = [Document(content="a"), Document(content="b", score=0.5)] - output = joiner.run([documents]) - assert "those with score=None were sorted as if they had a score of -infinity" in caplog.text - assert output["documents"] == documents[::-1] - - @pytest.mark.unit - def test_output_documents_not_sorted_by_score(self): - joiner = DocumentJoiner(sort_by_score=False) - documents_1 = [Document(content="a", score=0.1)] - documents_2 = [Document(content="d", score=0.2)] - output = joiner.run([documents_1, documents_2]) - assert output["documents"] == documents_1 + documents_2 diff --git a/test/preview/components/routers/test_file_router.py b/test/preview/components/routers/test_file_router.py deleted file mode 100644 index b513d95830..0000000000 --- 
a/test/preview/components/routers/test_file_router.py +++ /dev/null @@ -1,139 +0,0 @@ -import sys - -import pytest - -from haystack.preview.components.routers.file_type_router import FileTypeRouter -from haystack.preview.dataclasses import ByteStream - - -@pytest.mark.skipif( - sys.platform in ["win32", "cygwin"], - reason="Can't run on Windows Github CI, need access to registry to get mime types", -) -class TestFileTypeRouter: - @pytest.mark.unit - def test_run(self, preview_samples_path): - """ - Test if the component runs correctly in the simplest happy path. - """ - file_paths = [ - preview_samples_path / "txt" / "doc_1.txt", - preview_samples_path / "txt" / "doc_2.txt", - preview_samples_path / "audio" / "the context for this answer is here.wav", - preview_samples_path / "images" / "apple.jpg", - ] - - router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) - output = router.run(sources=file_paths) - assert output - assert len(output["text/plain"]) == 2 - assert len(output["audio/x-wav"]) == 1 - assert len(output["image/jpeg"]) == 1 - assert not output["unclassified"] - - @pytest.mark.unit - def test_run_with_bytestreams(self, preview_samples_path): - """ - Test if the component runs correctly with ByteStream inputs. - """ - file_paths = [ - preview_samples_path / "txt" / "doc_1.txt", - preview_samples_path / "txt" / "doc_2.txt", - preview_samples_path / "audio" / "the context for this answer is here.wav", - preview_samples_path / "images" / "apple.jpg", - ] - mime_types = ["text/plain", "text/plain", "audio/x-wav", "image/jpeg"] - # Convert file paths to ByteStream objects and set metadata - byte_streams = [] - for path, mime_type in zip(file_paths, mime_types): - stream = ByteStream(path.read_bytes()) - - stream.metadata["content_type"] = mime_type - - byte_streams.append(stream) - - # add unclassified ByteStream - bs = ByteStream(b"unclassified content") - bs.metadata["content_type"] = "unknown_type" - byte_streams.append(bs) - - router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) - output = router.run(sources=byte_streams) - assert output - assert len(output["text/plain"]) == 2 - assert len(output["audio/x-wav"]) == 1 - assert len(output["image/jpeg"]) == 1 - assert len(output.get("unclassified")) == 1 - - @pytest.mark.unit - def test_run_with_bytestreams_and_file_paths(self, preview_samples_path): - file_paths = [ - preview_samples_path / "txt" / "doc_1.txt", - preview_samples_path / "audio" / "the context for this answer is here.wav", - preview_samples_path / "txt" / "doc_2.txt", - preview_samples_path / "images" / "apple.jpg", - ] - mime_types = ["text/plain", "audio/x-wav", "text/plain", "image/jpeg"] - byte_stream_sources = [] - for path, mime_type in zip(file_paths, mime_types): - stream = ByteStream(path.read_bytes()) - stream.metadata["content_type"] = mime_type - byte_stream_sources.append(stream) - - mixed_sources = file_paths[:2] + byte_stream_sources[2:] - - router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) - output = router.run(sources=mixed_sources) - assert len(output["text/plain"]) == 2 - assert len(output["audio/x-wav"]) == 1 - assert len(output["image/jpeg"]) == 1 - - @pytest.mark.unit - def test_no_files(self): - """ - Test that the component runs correctly when no files are provided. 
- """ - router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) - output = router.run(sources=[]) - assert not output - - @pytest.mark.unit - def test_unlisted_extensions(self, preview_samples_path): - """ - Test that the component correctly handles files with non specified mime types. - """ - file_paths = [ - preview_samples_path / "txt" / "doc_1.txt", - preview_samples_path / "audio" / "ignored.mp3", - preview_samples_path / "audio" / "this is the content of the document.wav", - ] - router = FileTypeRouter(mime_types=["text/plain"]) - output = router.run(sources=file_paths) - assert len(output["text/plain"]) == 1 - assert "mp3" not in output - assert len(output["unclassified"]) == 2 - assert str(output["unclassified"][0]).endswith("ignored.mp3") - assert str(output["unclassified"][1]).endswith("this is the content of the document.wav") - - @pytest.mark.unit - def test_no_extension(self, preview_samples_path): - """ - Test that the component ignores files with no extension. - """ - file_paths = [ - preview_samples_path / "txt" / "doc_1.txt", - preview_samples_path / "txt" / "doc_2", - preview_samples_path / "txt" / "doc_2.txt", - ] - router = FileTypeRouter(mime_types=["text/plain"]) - output = router.run(sources=file_paths) - assert len(output["text/plain"]) == 2 - assert len(output["unclassified"]) == 1 - - @pytest.mark.unit - def test_unknown_mime_type(self): - """ - Test that the component handles files with unknown mime types. - """ - with pytest.raises(ValueError, match="Unknown mime type:"): - FileTypeRouter(mime_types=["type_invalid"]) diff --git a/test/preview/components/routers/test_metadata_router.py b/test/preview/components/routers/test_metadata_router.py deleted file mode 100644 index 5109f8db6c..0000000000 --- a/test/preview/components/routers/test_metadata_router.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest - -from haystack.preview import Document -from haystack.preview.components.routers.metadata_router import MetadataRouter - - -class TestMetadataRouter: - @pytest.mark.unit - def test_run(self): - rules = { - "edge_1": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-01-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2023-04-01"}, - ], - }, - "edge_2": { - "operator": "AND", - "conditions": [ - {"field": "meta.created_at", "operator": ">=", "value": "2023-04-01"}, - {"field": "meta.created_at", "operator": "<", "value": "2023-07-01"}, - ], - }, - } - router = MetadataRouter(rules=rules) - documents = [ - Document(meta={"created_at": "2023-02-01"}), - Document(meta={"created_at": "2023-05-01"}), - Document(meta={"created_at": "2023-08-01"}), - ] - output = router.run(documents=documents) - assert output["edge_1"][0].meta["created_at"] == "2023-02-01" - assert output["edge_2"][0].meta["created_at"] == "2023-05-01" - assert output["unmatched"][0].meta["created_at"] == "2023-08-01" diff --git a/test/preview/components/routers/test_text_language_router.py b/test/preview/components/routers/test_text_language_router.py deleted file mode 100644 index c3e333ec2b..0000000000 --- a/test/preview/components/routers/test_text_language_router.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging -import pytest - -from haystack.preview import Document -from haystack.preview.components.routers import TextLanguageRouter - - -class TestTextLanguageRouter: - @pytest.mark.unit - def test_non_string_input(self): - with pytest.raises(TypeError, match="TextLanguageRouter expects a str as input."): - 
classifier = TextLanguageRouter() - classifier.run(text=Document(content="This is an english sentence.")) - - @pytest.mark.unit - def test_list_of_string(self): - with pytest.raises(TypeError, match="TextLanguageRouter expects a str as input."): - classifier = TextLanguageRouter() - classifier.run(text=["This is an english sentence."]) - - @pytest.mark.unit - def test_empty_string(self): - classifier = TextLanguageRouter() - result = classifier.run(text="") - assert result == {"unmatched": ""} - - @pytest.mark.unit - def test_detect_language(self): - classifier = TextLanguageRouter() - detected_language = classifier.detect_language("This is an english sentence.") - assert detected_language == "en" - - @pytest.mark.unit - def test_route_to_en(self): - classifier = TextLanguageRouter() - english_sentence = "This is an english sentence." - result = classifier.run(text=english_sentence) - assert result == {"en": english_sentence} - - @pytest.mark.unit - def test_route_to_unmatched(self): - classifier = TextLanguageRouter() - german_sentence = "Ein deutscher Satz ohne Verb." - result = classifier.run(text=german_sentence) - assert result == {"unmatched": german_sentence} - - @pytest.mark.unit - def test_warning_if_no_language_detected(self, caplog): - with caplog.at_level(logging.WARNING): - classifier = TextLanguageRouter() - classifier.run(text=".") - assert "Langdetect cannot detect the language of text: ." in caplog.text diff --git a/test/preview/components/samplers/__init__.py b/test/preview/components/samplers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/samplers/test_top_p.py b/test/preview/components/samplers/test_top_p.py deleted file mode 100644 index 74d0e269db..0000000000 --- a/test/preview/components/samplers/test_top_p.py +++ /dev/null @@ -1,89 +0,0 @@ -import random - -import pytest - -from haystack.preview import Document, ComponentError -from haystack.preview.components.samplers.top_p import TopPSampler - - -class TestTopPSampler: - @pytest.mark.unit - def test_run_scores_from_metadata(self): - """ - Test if the component runs correctly with scores already in the metadata. - """ - sampler = TopPSampler(top_p=0.95, score_field="similarity_score") - docs = [ - Document(content="Berlin", meta={"similarity_score": -10.6}), - Document(content="Belgrade", meta={"similarity_score": -8.9}), - Document(content="Sarajevo", meta={"similarity_score": -4.6}), - ] - output = sampler.run(documents=docs) - docs = output["documents"] - assert len(docs) == 1 - assert docs[0].content == "Sarajevo" - - @pytest.mark.unit - def test_run_scores(self): - """ - Test if the component runs correctly with scores in the Document score field. - """ - sampler = TopPSampler(top_p=0.99) - docs = [ - Document(content="Berlin", score=-10.6), - Document(content="Belgrade", score=-8.9), - Document(content="Sarajevo", score=-4.6), - ] - - random.shuffle(docs) - sorted_scores = sorted([doc.score for doc in docs], reverse=True) - - # top_p = 0.99 will get the top 1 document - output = sampler.run(documents=docs) - docs_filtered = output["documents"] - assert len(docs_filtered) == 1 - assert docs_filtered[0].content == "Sarajevo" - - assert [doc.score for doc in docs_filtered] == sorted_scores[:1] - - @pytest.mark.unit - def test_run_scores_top_p_1(self): - """ - Test if the component runs correctly top_p=1. 
- """ - sampler = TopPSampler(top_p=1.0) - docs = [ - Document(content="Berlin", score=-10.6), - Document(content="Belgrade", score=-8.9), - Document(content="Sarajevo", score=-4.6), - ] - - random.shuffle(docs) - output = sampler.run(documents=docs) - docs_filtered = output["documents"] - assert len(docs_filtered) == len(docs) - assert docs_filtered[0].content == "Sarajevo" - - assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True) - - # Returns an empty list if no documents are provided - @pytest.mark.unit - def test_returns_empty_list_if_no_documents_are_provided(self): - sampler = TopPSampler() - output = sampler.run(documents=[]) - assert output["documents"] == [] - - @pytest.mark.unit - def test_run_scores_no_metadata_present(self): - """ - Test if the component runs correctly with scores missing from the metadata yet being specified in the - score_field. - """ - sampler = TopPSampler(top_p=0.95, score_field="similarity_score") - docs = [ - Document(content="Berlin", score=-10.6), - Document(content="Belgrade", score=-8.9), - Document(content="Sarajevo", score=-4.6), - ] - with pytest.raises(ComponentError, match="Score field 'similarity_score' not found"): - sampler.run(documents=docs) diff --git a/test/preview/components/websearch/__init__.py b/test/preview/components/websearch/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/websearch/test_searchapi.py b/test/preview/components/websearch/test_searchapi.py deleted file mode 100644 index c1aef566b1..0000000000 --- a/test/preview/components/websearch/test_searchapi.py +++ /dev/null @@ -1,445 +0,0 @@ -import os -from unittest.mock import Mock, patch - -import pytest -from requests import Timeout, RequestException, HTTPError - -from haystack.preview import Document -from haystack.preview.components.websearch.searchapi import SearchApiError, SearchApiWebSearch - - -EXAMPLE_SEARCHAPI_RESPONSE = { - "search_metadata": { - "id": "search_Y16dWXw4JOrIwNjjvqoKNGlE", - "status": "Success", - "created_at": "2023-11-22T16:10:56Z", - "request_time_taken": 1.98, - "parsing_time_taken": 0.16, - "total_time_taken": 2.15, - "request_url": "https://www.google.com/search?q=Who+is+CEO+of+Microsoft%3F&oq=Who+is+CEO+of+Microsoft%3F&gl=us&hl=en&ie=UTF-8", - "html_url": "https://www.searchapi.io/api/v1/searches/search_Y16dWXw4JOrIwNjjvqoKNGlE.html", - "json_url": "https://www.searchapi.io/api/v1/searches/search_Y16dWXw4JOrIwNjjvqoKNGlE", - }, - "search_parameters": { - "engine": "google", - "q": "Who is CEO of Microsoft?", - "device": "desktop", - "google_domain": "google.com", - "hl": "en", - "gl": "us", - }, - "search_information": { - "query_displayed": "Who is CEO of Microsoft?", - "total_results": 429000000, - "time_taken_displayed": 0.48, - }, - "answer_box": { - "type": "organic_result", - "title": "Microsoft Corporation/CEO", - "answer": "Satya Nadella", - "answer_date": "Feb 4, 2014–", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Satya+Nadella&stick=H4sIAAAAAAAAAONgVuLSz9U3KDQxqMjKesRoyi3w8sc9YSmdSWtOXmNU4-IKzsgvd80rySypFJLgYoOy-KR4uJC08Sxi5Q1OLKlMVPBLTEnNyUkEALvb1RBWAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQzIcDKAB6BAgyEAE", - "snippet": "Microsoft CEO Satya Nadella speaks during the OpenAI DevDay event on November 06, 2023 in San Francisco, California.", - "date": "1 day ago", - "organic_result": { - "title": "Microsoft CEO Satya 
Nadella's response to the OpenAI board ...", - "link": "https://fortune.com/2023/11/21/microsoft-ceo-satya-nadella-openai-ceo-sam-altman-move-fast-fix-things/#:~:text=Microsoft%20CEO%20Satya%20Nadella%20speaks,2023%20in%20San%20Francisco%2C%20California.", - "source": "Fortune", - "domain": "fortune.com", - "displayed_link": "https://fortune.com › 2023/11/21 › microsoft-ceo-satya-...", - }, - "people_also_search_for": [ - { - "title": "Sundar Pichai", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Sundar+Pichai&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1HiArEs01OKzU20-AJSi4rz84IzU1LLEyuLF7HyBpfmpSQWKQRkJmckZgIAJfaYezgAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEAQ", - }, - { - "title": "Steve Ballmer", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Steve+Ballmer&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1ECs8yTssu0-AJSi4rz84IzU1LLEyuLF7HyBpeklqUqOCXm5OSmFgEA31ogfDYAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEAY", - }, - { - "title": "Anupama Nadella", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Anupama+Nadella&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1Hi1U_XNzRMMjPMzTHMMtHiC0gtKs7PC85MSS1PrCxexMrvmFdakJibqOCXmJKak5MIAEx0yhM9AAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEAg", - }, - { - "title": "Zain Nadella", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Zain+Nadella&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1Hi1U_XNzRMMjMyKCgsj9fiC0gtKs7PC85MSS1PrCxexMoTlZiZp-CXmJKak5MIANDRqOs6AAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEAo", - }, - { - "title": "Bill Gates", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Bill+Gates&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1ECswzN80q0-AJSi4rz84IzU1LLEyuLF7FyOWXm5Ci4J5akFgMAF5_u-TMAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEAw", - }, - { - "title": "Shantanu Narayen", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Shantanu+Narayen&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1HiArGMzC0ts5O0-AJSi4rz84IzU1LLEyuLF7EKBGck5pUk5pUq-CUWJVam5gEA2xdRszsAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEA4", - }, - { - "title": "Paul Allen", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Paul+Allen&stick=H4sIAAAAAAAAAONgFuLQz9U3MCkuM1ECs0xLsnO1-AJSi4rz84IzU1LLEyuLF7FyBSSW5ig45uSk5gEA_4-yKDMAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQxA16BAgnEBA", - }, - ], - }, - "knowledge_graph": { - "kgmid": "/m/0q40xjj", - "knowledge_graph_type": "People", - "title": "Satya Nadella", - "type": "CEO of Microsoft", - "description": "Satya Narayana Nadella is an Indian-American business executive. He is the executive chairman and CEO of Microsoft, succeeding Steve Ballmer in 2014 as CEO and John W. 
Thompson in 2021 as chairman.", - "source": {"name": "Wikipedia", "link": "https://en.wikipedia.org/wiki/Satya_Nadella"}, - "born": "August 19, 1967 (age 56 years), Hyderabad, India", - "born_links": [ - { - "text": "Hyderabad, India", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Hyderabad&si=ALGXSlZS0YT-iRe81F2cKC9lM9KWTK4y0m5Atx8g9YliNNw2meVELJr66A46Jmr2L7YaEMWXarsN12T-Vg9bXBeu7mCHCG-SpT-gWQmluIDs5SvdST1r6rBUhcAOclNosjy4RgkGlWnecyHsBen2Ttz-NbCqTmTwwPK9ro0lfOFPb0CUDvLAkTbBXx4xNX7WWUJ19n0EWeuA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECGUQAg", - } - ], - "awards": "Padma Bhushan, CNN-IBN Indian of the Year Global Indian", - "awards_links": [ - { - "text": "Padma Bhushan", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Padma+Bhushan&si=ALGXSlYh1-GEPndq7qMo--O-TPixQtNN4JMroSxgItz5kq0stCyOa5BGWGIYt20KbMd-zdQdvwREsU7qSkWcyv0yzHS195H46le5meMq90to5z-nIHo4evgG3koKwps5uC-gu8Huemxmq6P1usjVEj5YR9okGopoUaOxuuyZP-isnQAmC6otzjnjf1O9jMuQObZmAnl2HH7coBXCHbIx1QvAHw1KZOYyJKPnYhWaYgqfQo7yF5BOVVLXvtr_8FhnFIxxl7f_V2B6&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECF8QAg", - }, - { - "text": "CNN-IBN Indian of the Year Global Indian", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=CNN-IBN+Indian+of+the+Year+Global+Indian&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diF8YzT48GvLer52UWWyGCjf3yeWD9YQzPqUV-LEVPLmirdkrJ_7HPexciHWOKnyaMVi0vXdKPSwvc8pE4fD3qmgVyw7qAFoNmy-T-U6OlosYKKVbf9CZnaOonmPhLRRFHGEEmKVtb_0FdKkXeUE2RIDgUJ1n1LWZoTeporPHOj4JfKSJADc-hymzzDEb5-uW3KxQtTdv_GJNMOoleFxqH9cvObQvW0_NvpfHZcThW9b_9g1BXjLfozVqh6hjRTbb40p5vu5e9Oi4sNqxtACf4Xoys_QX5&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAXoECF8QAw", - }, - ], - "nominations": "CNN-IBN Indian of the Year Global Indian", - "nominations_links": [ - { - "text": "CNN-IBN Indian of the Year Global Indian", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=CNN-IBN+Indian+of+the+Year+Global+Indian&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diF8YzT48GvLer52UWWyGCjf3yeWD9YQzPqUV-LEVPLmirdkrJ_7HPexciHWOKnyaMVlh5LgokSYRM8a-Dib-kzfIaD6Uw_x_3lxo6j3NNKQbuBs4v4kkSCjL68joimLMo16eCX83PFrnvSsVKsgu6aFRoBYQt5p5NRofNfBXtVt2jzFVAWh23VsBHyyAxOuC2aQmgvKp-FGYymourIbHCdJ3rcx-Z&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECGIQAg", - } - ], - "books": "Hit Refresh: The Quest to Rediscover Microsoft's Soul and Imagine a Better Future for Everyone", - "books_links": [ - { - "text": "Hit Refresh: The Quest to Rediscover Microsoft's Soul and Imagine a Better Future for Everyone", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Hit+Refresh&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diFzM3kSV8cu0gYZuy4n6At7XJ8qKh8mnRaXfDbxUaZoS_kPW87tGFHpw6B9zAS2a52vwJDx-fkzytheyPXaMQENZSl3bwqC9Nz3bqn7-Pglqh0Bik5Ow9AdVr2XI8mdVktN4SkCIaPE4qQfjAurt8rjUVyQzu3OFQx04nfPH3Gv7vP8aKqg%3D%3D&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECGEQAg", - } - ], - "children": "Zain Nadella, Divya Nadella, Tara Nadella", - "children_links": [ - { - "text": "Zain Nadella", - "link": 
"https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Zain+Nadella&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diFxtguEvrMR1GmF2uy_-DLcVXwF5gosIuQudQPkad9bBUZxVKOG9PFdHXdEGQHrfXekG0E0x_raEKuDnHD6kk8_HfD3LZ57PWZ3Zyz0uhKPE15DfvA42IpAByWbms0fsgRw5IFCWwB5XMd3WM5U8KKsgeb_DmdoooQ_k3RrxO57jTcm5ZwgDlpBpGq0wj2Ksc2A65RQvA8NPJtpEqDcvEpJ4xWQ_tM_rHduCXRfsv9XFr84DzwA%3D%3D&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECGQQAg", - }, - { - "text": "Divya Nadella", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Divya+Nadella&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diFwYr_pFPi4_6apkHPz96V-E6wDawAGH_i6kZL7ZB-ETzV3LLESN1a8BgFguu3LOpz1qAQypmcVosQxCFWSJVexciDel34yrgWJmUu5bY2zzEmu1h95LQ35yUDkf6Mqcn-TiwyLu7OzGYkw6D9P4kNkS2D3gNPnRZb6vQJbqdayQg-wgn-LG2BmwR-RntneXFgSSZgotziGaY96UzeZ0zgRWYp6LAKlRqlTbeDeCbDDY2_VIWjQ%3D%3D&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAXoECGQQAw", - }, - { - "text": "Tara Nadella", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Tara+Nadella&si=ALGXSlZZLz93Q5j8HVkpXyxpTaoqXw8cocmoi-DFAGsSj5diF465A_RPTnaELE1D-l5XgaKmBEpoAyayrOAdoXqBSLZ8Qu5UB1hBz6xLN4I1DdUSzqN0G0e9_8lfDbD_Qnx2uLJL_3XUNJ3gPrjCNvCyYeR9a9wkCnMBLchfUhVji9EHiobO4WgdWkxKd44YXHxfMBIYEek8OfbdUx9tplETPYtu7X1HRtGzqp8lXsQ6Vacj-aT7K6Xw0psbP4NXwHRQ71MYjLS-A5_VpSnitGScPsP-1m41Kg%3D%3D&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAnoECGQQBA", - }, - ], - "education": "University of Wisconsin-Milwaukee (1990), MORE", - "education_links": [ - { - "text": "University of Wisconsin-Milwaukee", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=University+of+Wisconsin+Milwaukee&si=ALGXSlYh1-GEPndq7qMo--O-TPixQtNN4JMroSxgItz5kq0stDerMl6ZWDVLeoc_9LMxC6poOvWKyPlaxlQHHC7l9sV2e_sYZ2w92bas10emnFKqvF8PcMhCIIHCiTbdtg6nHIA-ihu0l0dNJtl3ZXuRejodvwikfjAsz-cGgFCLkxoi_eMM95SSZ77VXB0gP7fPTA6q__pIRK7T6ZfiSyM2xTbDt3YUvrWFmx5LBSJwRd2K1f0DK6sGaIa3ozdQOGvGXZkTOTLEG_a2ssbGBTX4MyU4cHmLsvW-Gfpq-makl3esSS7fQTc%3D&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQmxMoAHoECGAQAg", - }, - { - "text": "MORE", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=satya+nadella+education&stick=H4sIAAAAAAAAAOPgE-LSz9U3KDQxqMjK0pLOTrbSL0jNL8hJBVJFxfl5VqkppcmJJZn5eYtYxYsTSyoTFfISU1JzchIV4DIAcrWm-UUAAAA&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ44YBKAF6BAhgEAM", - }, - ], - "full_name": "Satya Narayana Nadella", - "profiles": [ - {"name": "LinkedIn", "link": "https://www.linkedin.com/in/satyanadella"}, - {"name": "Twitter", "link": "https://twitter.com/satyanadella"}, - ], - }, - "organic_results": [ - { - "position": 1, - "title": "Satya Nadella - Stories", - "link": "https://news.microsoft.com/exec/satya-nadella/", - "source": "Microsoft", - "domain": "news.microsoft.com", - "displayed_link": "https://news.microsoft.com › exec › satya-nadella", - "snippet": "Satya Nadella is Chairman and Chief Executive Officer of Microsoft. 
Before being named CEO in February 2014, Nadella held leadership roles in both ...", - "snippet_highlighted_words": ["Satya Nadella"], - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:jTiZ69Cck7EJ:https://news.microsoft.com/exec/satya-nadella/&hl=en&gl=us", - }, - { - "position": 2, - "title": "Satya Nadella", - "link": "https://en.wikipedia.org/wiki/Satya_Nadella", - "source": "Wikipedia", - "domain": "en.wikipedia.org", - "displayed_link": "https://en.wikipedia.org › wiki › Satya_Nadella", - "snippet": "Satya Narayana Nadella is an Indian-American business executive. He is the executive chairman and CEO of Microsoft, succeeding Steve Ballmer in 2014 as CEO ...", - "snippet_highlighted_words": ["Satya Narayana Nadella"], - "sitelinks": { - "inline": [ - { - "title": "Manipal Institute of Technology", - "link": "https://en.wikipedia.org/wiki/Manipal_Institute_of_Technology", - }, - { - "title": "University of Wisconsin", - "link": "https://en.wikipedia.org/wiki/University_of_Wisconsin%E2%80%93Milwaukee", - }, - {"title": "S. Somasegar", "link": "https://en.wikipedia.org/wiki/S._Somasegar"}, - ] - }, - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:Tgw93hG0PnoJ:https://en.wikipedia.org/wiki/Satya_Nadella&hl=en&gl=us", - }, - { - "position": 3, - "title": "Satya Nadella", - "link": "https://www.linkedin.com/in/satyanadella", - "source": "LinkedIn · Satya Nadella", - "domain": "www.linkedin.com", - "displayed_link": "10.5M+ followers", - "snippet": "As chairman and CEO of Microsoft, I define my mission and that of my company as empowering… | Learn more about Satya Nadella's work experience, education, ...", - "snippet_highlighted_words": ["Satya Nadella's"], - }, - { - "position": 4, - "title": "Who is Satya Nadella, Family, Salary, Education, Net Worth ...", - "link": "https://www.business-standard.com/about/who-is-satya-nadella", - "source": "Business Standard", - "domain": "www.business-standard.com", - "displayed_link": "https://www.business-standard.com › about › who-is-s...", - "snippet": "Satya Narayana Nadella is the chief executive officer (CEO) of Microsoft. 
Under him, Microsoft has more cloud computing revenue than Google, more subscribers ...", - "snippet_highlighted_words": ["Satya Narayana Nadella"], - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:yQ0bmLSmP8gJ:https://www.business-standard.com/about/who-is-satya-nadella&hl=en&gl=us", - }, - { - "position": 5, - "title": "Satya Nadella (@satyanadella) / X", - "link": "https://twitter.com/satyanadella", - "source": "Twitter · satyanadella", - "domain": "twitter.com", - "displayed_link": "3.1M+ followers", - "snippet": "Chairman and CEO of Microsoft Corporation.", - "snippet_highlighted_words": ["CEO of Microsoft"], - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:dEJiGKzwLfkJ:https://twitter.com/satyanadella&hl=en&gl=us", - }, - { - "position": 6, - "title": "Satya Nadella | Biography & Facts", - "link": "https://www.britannica.com/biography/Satya-Nadella", - "source": "Britannica", - "domain": "www.britannica.com", - "displayed_link": "https://www.britannica.com › biography › Satya-Nadella", - "snippet": "Satya Nadella (born August 19, 1967, Hyderabad, India) Indian-born business executive who was CEO of the computer software company Microsoft (2014– ).", - "snippet_highlighted_words": ["Satya Nadella"], - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:a0S8ke4I9qgJ:https://www.britannica.com/biography/Satya-Nadella&hl=en&gl=us", - }, - { - "position": 7, - "title": "Satya Nadella", - "link": "https://www.forbes.com/profile/satya-nadella/", - "source": "Forbes", - "domain": "www.forbes.com", - "displayed_link": "https://www.forbes.com › profile › satya-nadella", - "snippet": "Satya Nadella replaced billionaire Steve Ballmer as Microsoft CEO in 2014. 
Prior to that, Nadella was Microsoft EVP of the cloud and enterprise group.", - "snippet_highlighted_words": ["Satya Nadella"], - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:q_CXTYNnHSMJ:https://www.forbes.com/profile/satya-nadella/&hl=en&gl=us", - }, - { - "position": 8, - "title": "5 Facts You Didn't Know About Microsoft CEO Satya Nadella", - "link": "https://in.benzinga.com/content/35911756/5-facts-you-didnt-know-about-microsoft-ceo-satya-nadella", - "source": "Benzinga", - "domain": "in.benzinga.com", - "displayed_link": "https://in.benzinga.com › content › 5-facts-you-didnt-...", - "snippet": "Satya Nadella's journey at Microsoft underscores the importance of diverse experiences in shaping effective and empathetic leadership in the ...", - "snippet_highlighted_words": ["Satya Nadella's"], - "date": "8 hours ago", - "cached_page_link": "https://webcache.googleusercontent.com/search?q=cache:hCbtJUTgvEQJ:https://in.benzinga.com/content/35911756/5-facts-you-didnt-know-about-microsoft-ceo-satya-nadella&hl=en&gl=us", - }, - { - "position": 9, - "title": "Microsoft CEO Satya Nadella: Q&A - The Wall Street Journal", - "link": "https://www.wsj.com/video/microsoft-ceo-satya-nadella-qa/41D02815-935C-421D-8021-5E1BFD3DDE84", - "source": "Wall Street Journal", - "domain": "www.wsj.com", - "displayed_link": "https://www.wsj.com › video › microsoft-ceo-satya-nadel...", - "snippet": "Microsoft CEO Satya Nadella talks about his biggest accomplishment, how to make successful acquisitions and how the tech industry could improve its image ...", - "snippet_highlighted_words": ["Microsoft CEO"], - "video": {"source": "The Wall Street Journal", "channel": "The Wall Street Journal", "date": "Feb 1, 2019"}, - }, - ], - "related_questions": [ - { - "question": "Who is the real CEO of Microsoft?", - "answer": "Satya Nadella is Chairman and Chief Executive Officer of Microsoft.", - "answer_highlight": "Satya Nadella", - "source": { - "title": "Satya Nadella - Stories - Microsoft News", - "link": "https://news.microsoft.com/exec/satya-nadella/#:~:text=Satya%20Nadella%20is%20Chairman%20and%20Chief%20Executive%20Officer%20of%20Microsoft.", - "source": "Microsoft", - "domain": "news.microsoft.com", - "displayed_link": "https://news.microsoft.com › exec › satya-nadella", - }, - "search": { - "title": "Search for: Who is the real CEO of Microsoft?", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+the+real+CEO+of+Microsoft%3F&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQzmd6BAgeEAY", - }, - }, - { - "question": "Who is the CEO of Microsoft 2023?", - "answer": "Microsoft Corp. 
chief executive officer Satya Nadella signaled that he'd be open to Sam Altman going back to OpenAI, rather than joining his company as part of a surprise move announced over the weekend.", - "date": "2 days ago", - "source": { - "title": "Microsoft CEO Satya Nadella signals willingness to have Sam Altman ...", - "link": "https://economictimes.indiatimes.com/tech/technology/microsoft-ceo-satya-nadella-signals-willingness-to-have-sam-altman-rejoin-openai/articleshow/105370026.cms#:~:text=Microsoft%20Corp.%20chief%20executive%20officer,move%20announced%20over%20the%20weekend.", - "source": "indiatimes.com", - "domain": "economictimes.indiatimes.com", - "displayed_link": "https://economictimes.indiatimes.com › tech › articleshow", - }, - "search": { - "title": "Search for: Who is the CEO of Microsoft 2023?", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+the+CEO+of+Microsoft+2023%3F&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQzmd6BAgcEAY", - }, - }, - { - "question": "How many degrees does Satya Nadella have?", - "answer": "He earned a bachelor's degree in electrical engineering from Mangalore University, a master's degree in computer science from the University of Wisconsin – Milwaukee and a master's degree in business administration from the University of Chicago.", - "source": { - "title": "Satya Nadella - Institutional - BlackRock", - "link": "https://www.blackrock.com/institutions/en-zz/biographies/satya-nadella#:~:text=He%20earned%20a%20bachelor's%20degree,from%20the%20University%20of%20Chicago.", - "source": "blackrock.com", - "domain": "www.blackrock.com", - "displayed_link": "https://www.blackrock.com › en-zz › biographies › satya...", - }, - "search": { - "title": "Search for: How many degrees does Satya Nadella have?", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=How+many+degrees+does+Satya+Nadella+have%3F&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQzmd6BAgdEAY", - }, - }, - { - "question": "How old is Satya Nadella?", - "answer_highlight": "56 years (August 19, 1967)", - "entity": {"subject": "Satya Nadella", "attribute": "Age", "value": "56 years (August 19, 1967)"}, - "search": { - "title": "Search for: How old is Satya Nadella?", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=How+old+is+Satya+Nadella%3F&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQzmd6BAgREAY", - }, - }, - ], - "related_searches": [ - { - "query": "Who is ceo of microsoft wife", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+ceo+of+microsoft+wife&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhWEAE", - }, - { - "query": "Who is ceo of microsoft and microsoft", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+ceo+of+microsoft+and+microsoft&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhVEAE", - }, - { - "query": "Who is ceo of microsoft wikipedia", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+ceo+of+microsoft+wikipedia&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhUEAE", - }, - { - "query": "microsoft founder", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Microsoft+founder&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhSEAE", - }, - { - "query": "Who is ceo of 
microsoft 2020", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Who+is+ceo+of+microsoft+2020&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhTEAE", - }, - { - "query": "satya nadella net worth", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=Satya+Nadella+net+worth&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhREAE", - }, - { - "query": "ceo of microsoft salary", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=CEO+of+Microsoft+salary&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhQEAE", - }, - { - "query": "ceo of apple", - "link": "https://www.google.com/search?sca_esv=584620230&gl=us&hl=en&q=CEO+of+Apple&sa=X&ved=2ahUKEwi89re3_9eCAxU4IUQIHfHeB6MQ1QJ6BAhXEAE", - }, - ], -} - - -@pytest.fixture -def mock_searchapi_search_result(): - with patch("haystack.preview.components.websearch.searchapi.requests.get") as mock_get: - mock_get.return_value = Mock(status_code=200, json=lambda: EXAMPLE_SEARCHAPI_RESPONSE) - yield mock_get - - -class TestSearchApiSearchAPI: - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("SEARCHAPI_API_KEY", raising=False) - with pytest.raises(ValueError, match="SearchApiWebSearch expects an API key"): - SearchApiWebSearch() - - @pytest.mark.unit - def test_to_dict(self): - component = SearchApiWebSearch( - api_key="api_key", top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"} - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.websearch.searchapi.SearchApiWebSearch", - "init_parameters": { - "top_k": 10, - "allowed_domains": ["testdomain.com"], - "search_params": {"param": "test params"}, - }, - } - - @pytest.mark.unit - @pytest.mark.parametrize("top_k", [1, 5, 7]) - def test_web_search_top_k(self, mock_searchapi_search_result, top_k: int): - ws = SearchApiWebSearch(api_key="api_key", top_k=top_k) - results = ws.run(query="Who is CEO of Microsoft?") - documents = results["documents"] - links = results["links"] - assert len(documents) == len(links) == top_k - assert all(isinstance(doc, Document) for doc in documents) - assert all(isinstance(link, str) for link in links) - assert all(link.startswith("http") for link in links) - - @pytest.mark.unit - @patch("requests.get") - def test_timeout_error(self, mock_get): - mock_get.side_effect = Timeout - ws = SearchApiWebSearch(api_key="api_key") - - with pytest.raises(TimeoutError): - ws.run(query="Who is CEO of Microsoft?") - - @pytest.mark.unit - @patch("requests.get") - def test_request_exception(self, mock_get): - mock_get.side_effect = RequestException - ws = SearchApiWebSearch(api_key="api_key") - - with pytest.raises(SearchApiError): - ws.run(query="Who is CEO of Microsoft?") - - @pytest.mark.unit - @patch("requests.get") - def test_bad_response_code(self, mock_get): - mock_response = mock_get.return_value - mock_response.status_code = 404 - mock_response.raise_for_status.side_effect = HTTPError - ws = SearchApiWebSearch(api_key="api_key") - - with pytest.raises(SearchApiError): - ws.run(query="Who is CEO of Microsoft?") - - @pytest.mark.skipif( - not os.environ.get("SEARCHAPI_API_KEY", None), - reason="Export an env var called SEARCHAPI_API_KEY containing the SearchApi API key to run this test.", - ) - @pytest.mark.integration - def test_web_search(self): - ws = SearchApiWebSearch(api_key=os.environ.get("SEARCHAPI_API_KEY", None), 
top_k=10) - results = ws.run(query="Who is CEO of Microsoft?") - documents = results["documents"] - links = results["links"] - assert len(documents) == len(links) == 10 - assert all(isinstance(doc, Document) for doc in results) - assert all(isinstance(link, str) for link in links) - assert all(link.startswith("http") for link in links) diff --git a/test/preview/components/websearch/test_serperdev.py b/test/preview/components/websearch/test_serperdev.py deleted file mode 100644 index 48d8619d21..0000000000 --- a/test/preview/components/websearch/test_serperdev.py +++ /dev/null @@ -1,182 +0,0 @@ -import os -from unittest.mock import Mock, patch - -import pytest -from requests import Timeout, RequestException, HTTPError - -from haystack.preview import Document -from haystack.preview.components.websearch.serper_dev import SerperDevWebSearch, SerperDevError - - -EXAMPLE_SERPERDEV_RESPONSE = { - "searchParameters": { - "q": "Who is the boyfriend of Olivia Wilde?", - "gl": "us", - "hl": "en", - "autocorrect": True, - "type": "search", - }, - "organic": [ - { - "title": "Olivia Wilde embraces Jason Sudeikis amid custody battle, Harry Styles split - Page Six", - "link": "https://pagesix.com/2023/01/29/olivia-wilde-hugs-it-out-with-jason-sudeikis-after-harry-styles-split/", - "snippet": "Looks like Olivia Wilde and Jason Sudeikis are starting 2023 on good terms. Amid their highly publicized custody battle – and the actress' ...", - "date": "Jan 29, 2023", - "position": 1, - }, - { - "title": "Olivia Wilde Is 'Quietly Dating' Again Following Harry Styles Split: 'He Makes Her Happy'", - "link": "https://www.yahoo.com/now/olivia-wilde-quietly-dating-again-183844364.html", - "snippet": "Olivia Wilde is “quietly dating again” following her November 2022 split from Harry Styles, a source exclusively tells Life & Style.", - "date": "Feb 10, 2023", - "position": 2, - }, - { - "title": "Olivia Wilde and Harry Styles' Relationship Timeline: The Way They Were - Us Weekly", - "link": "https://www.usmagazine.com/celebrity-news/pictures/olivia-wilde-and-harry-styles-relationship-timeline/", - "snippet": "Olivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.", - "date": "Mar 10, 2023", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSgTcalNFvptTbYBiDXX55s8yCGfn6F1qbed9DAN16LvynTr9GayK5SPmY&s", - "position": 3, - }, - { - "title": "Olivia Wilde Is 'Ready to Date Again' After Harry Styles Split - Us Weekly", - "link": "https://www.usmagazine.com/celebrity-news/news/olivia-wilde-is-ready-to-date-again-after-harry-styles-split/", - "snippet": "Ready for love! Olivia Wilde is officially back on the dating scene following her split from her ex-boyfriend, Harry Styles.", - "date": "Mar 1, 2023", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCRAeRy5sVE631ZctzbzuOF70xkIOHaTvh2K7dYvdiVBwALiKrIjpscok&s", - "position": 4, - }, - { - "title": "Harry Styles and Olivia Wilde's Definitive Relationship Timeline - Harper's Bazaar", - "link": "https://www.harpersbazaar.com/celebrity/latest/a35172115/harry-styles-olivia-wilde-relationship-timeline/", - "snippet": "November 2020: News breaks about Olivia splitting from fiancé Jason Sudeikis. ... 
In mid-November, news breaks of Olivia Wilde's split from Jason ...", - "date": "Feb 23, 2023", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRRqw3fvZOIGHEepxCc7yFAWYsS_v_1H6X-4nxyFJxdfRuFQw_BrI6JVzI&s", - "position": 5, - }, - { - "title": "Harry Styles and Olivia Wilde's Relationship Timeline - People", - "link": "https://people.com/music/harry-styles-olivia-wilde-relationship-timeline/", - "snippet": "Harry Styles and Olivia Wilde first met on the set of Don't Worry Darling and stepped out as a couple in January 2021. Relive all their biggest relationship ...", - "position": 6, - }, - { - "title": "Jason Sudeikis and Olivia Wilde's Relationship Timeline - People", - "link": "https://people.com/movies/jason-sudeikis-olivia-wilde-relationship-timeline/", - "snippet": "Jason Sudeikis and Olivia Wilde ended their engagement of seven years in 2020. Here's a complete timeline of their relationship.", - "date": "Mar 24, 2023", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSleZoXusQyJJe2WMgIuck_cVaJ8AE0_hU2QxsXzYvKANi55UQlv82yAVI&s", - "position": 7, - }, - { - "title": "Olivia Wilde's anger at ex-boyfriend Harry Styles: She resents him and thinks he was using her | Marca", - "link": "https://www.marca.com/en/lifestyle/celebrities/2023/02/23/63f779a4e2704e8d988b4624.html", - "snippet": "The two started dating after Wilde split up with actor Jason Sudeikisin 2020. However, their relationship came to an end last November.", - "date": "Feb 23, 2023", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQBgJF2mSnIWCvPrqUqM4WTI9xPNWPyLvHuune85swpB1yE_G8cy_7KRh0&s", - "position": 8, - }, - { - "title": "Olivia Wilde's dating history: Who has the actress dated? 
| The US Sun", - "link": "https://www.the-sun.com/entertainment/5221040/olivia-wildes-dating-history/", - "snippet": "AMERICAN actress Olivia Wilde started dating Harry Styles in January 2021 after breaking off her engagement the year prior.", - "date": "Nov 19, 2022", - "imageUrl": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTpm8BToVFHJoH6yRggg0fLocLT9mt6lwsnRxFFDNdDGhDydzQiSKZ9__g&s", - "position": 9, - }, - ], - "relatedSearches": [ - {"query": "Harry Styles girlfriends in order"}, - {"query": "Harry Styles and Olivia Wilde engaged"}, - {"query": "Harry Styles and Olivia Wilde wedding"}, - {"query": "Who is Harry Styles married to"}, - {"query": "Jason Sudeikis Olivia Wilde relationship"}, - {"query": "Olivia Wilde and Jason Sudeikis kids"}, - {"query": "Olivia Wilde children"}, - {"query": "Harry Styles and Olivia Wilde age difference"}, - {"query": "Jason Sudeikis Olivia Wilde, Harry Styles"}, - ], -} - - -@pytest.fixture -def mock_serper_dev_search_result(): - with patch("haystack.preview.components.websearch.serper_dev.requests") as mock_run: - mock_run.post.return_value = Mock(status_code=200, json=lambda: EXAMPLE_SERPERDEV_RESPONSE) - yield mock_run - - -class TestSerperDevSearchAPI: - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("SERPERDEV_API_KEY", raising=False) - with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"): - SerperDevWebSearch() - - @pytest.mark.unit - def test_to_dict(self): - component = SerperDevWebSearch( - api_key="test_key", top_k=10, allowed_domains=["test.com"], search_params={"param": "test"} - ) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.websearch.serper_dev.SerperDevWebSearch", - "init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}}, - } - - @pytest.mark.unit - @pytest.mark.parametrize("top_k", [1, 5, 7]) - def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int): - ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k) - results = ws.run(query="Who is the boyfriend of Olivia Wilde?") - documents = results["documents"] - links = results["links"] - assert len(documents) == len(links) == top_k - assert all(isinstance(doc, Document) for doc in documents) - assert all(isinstance(link, str) for link in links) - assert all(link.startswith("http") for link in links) - - @pytest.mark.unit - @patch("requests.post") - def test_timeout_error(self, mock_post): - mock_post.side_effect = Timeout - ws = SerperDevWebSearch(api_key="some_invalid_key") - - with pytest.raises(TimeoutError): - ws.run(query="Who is the boyfriend of Olivia Wilde?") - - @pytest.mark.unit - @patch("requests.post") - def test_request_exception(self, mock_post): - mock_post.side_effect = RequestException - ws = SerperDevWebSearch(api_key="some_invalid_key") - - with pytest.raises(SerperDevError): - ws.run(query="Who is the boyfriend of Olivia Wilde?") - - @pytest.mark.unit - @patch("requests.post") - def test_bad_response_code(self, mock_post): - mock_response = mock_post.return_value - mock_response.status_code = 404 - mock_response.raise_for_status.side_effect = HTTPError - ws = SerperDevWebSearch(api_key="some_invalid_key") - - with pytest.raises(SerperDevError): - ws.run(query="Who is the boyfriend of Olivia Wilde?") - - @pytest.mark.skipif( - not os.environ.get("SERPERDEV_API_KEY", None), - reason="Export an env var called SERPERDEV_API_KEY containing 
the SerperDev API key to run this test.", - ) - @pytest.mark.integration - def test_web_search(self): - ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10) - results = ws.run(query="Who is the boyfriend of Olivia Wilde?") - documents = results["documents"] - links = results["documents"] - assert len(documents) == len(links) == 10 - assert all(isinstance(doc, Document) for doc in results) - assert all(isinstance(link, str) for link in links) - assert all(link.startswith("http") for link in links) diff --git a/test/preview/components/writers/__init__.py b/test/preview/components/writers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/preview/components/writers/test_document_writer.py b/test/preview/components/writers/test_document_writer.py deleted file mode 100644 index ed5b9a4119..0000000000 --- a/test/preview/components/writers/test_document_writer.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest - -from haystack.preview import Document, DeserializationError -from haystack.preview.testing.factory import document_store_class -from haystack.preview.components.writers.document_writer import DocumentWriter -from haystack.preview.document_stores import DuplicatePolicy -from haystack.preview.document_stores.in_memory import InMemoryDocumentStore - - -class TestDocumentWriter: - @pytest.mark.unit - def test_to_dict(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - component = DocumentWriter(document_store=mocked_docstore_class()) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.writers.document_writer.DocumentWriter", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "policy": "FAIL", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - component = DocumentWriter(document_store=mocked_docstore_class(), policy=DuplicatePolicy.SKIP) - data = component.to_dict() - assert data == { - "type": "haystack.preview.components.writers.document_writer.DocumentWriter", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "policy": "SKIP", - }, - } - - @pytest.mark.unit - def test_from_dict(self): - mocked_docstore_class = document_store_class("MockedDocumentStore") - data = { - "type": "haystack.preview.components.writers.document_writer.DocumentWriter", - "init_parameters": { - "document_store": { - "type": "haystack.preview.testing.factory.MockedDocumentStore", - "init_parameters": {}, - }, - "policy": "SKIP", - }, - } - component = DocumentWriter.from_dict(data) - assert isinstance(component.document_store, mocked_docstore_class) - assert component.policy == DuplicatePolicy.SKIP - - @pytest.mark.unit - def test_from_dict_without_docstore(self): - data = {"type": "DocumentWriter", "init_parameters": {}} - with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"): - DocumentWriter.from_dict(data) - - @pytest.mark.unit - def test_from_dict_without_docstore_type(self): - data = {"type": "DocumentWriter", "init_parameters": {"document_store": {"init_parameters": {}}}} - with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"): - DocumentWriter.from_dict(data) - - @pytest.mark.unit - def 
test_from_dict_nonexisting_docstore(self): - data = { - "type": "DocumentWriter", - "init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}}, - } - with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."): - DocumentWriter.from_dict(data) - - @pytest.mark.unit - def test_run(self): - document_store = InMemoryDocumentStore() - writer = DocumentWriter(document_store) - documents = [ - Document(content="This is the text of a document."), - Document(content="This is the text of another document."), - ] - - result = writer.run(documents=documents) - assert result["documents_written"] == 2 - - @pytest.mark.unit - def test_run_skip_policy(self): - document_store = InMemoryDocumentStore() - writer = DocumentWriter(document_store, policy=DuplicatePolicy.SKIP) - documents = [ - Document(content="This is the text of a document."), - Document(content="This is the text of another document."), - ] - - result = writer.run(documents=documents) - assert result["documents_written"] == 2 - - result = writer.run(documents=documents) - assert result["documents_written"] == 0 diff --git a/test/preview/conftest.py b/test/preview/conftest.py deleted file mode 100644 index 3a2f166120..0000000000 --- a/test/preview/conftest.py +++ /dev/null @@ -1,23 +0,0 @@ -from pathlib import Path -from unittest.mock import Mock -import pytest - -from haystack.preview.testing.test_utils import set_all_seeds - -set_all_seeds(0) - - -@pytest.fixture() -def mock_tokenizer(): - """ - Tokenizes the string by splitting on spaces. - """ - tokenizer = Mock() - tokenizer.encode = lambda text: text.split() - tokenizer.decode = lambda tokens: " ".join(tokens) - return tokenizer - - -@pytest.fixture() -def test_files_path(): - return Path(__file__).parent / "test_files" diff --git a/test/preview/dataclasses/test_byte_stream.py b/test/preview/dataclasses/test_byte_stream.py deleted file mode 100644 index 4c3f0e154c..0000000000 --- a/test/preview/dataclasses/test_byte_stream.py +++ /dev/null @@ -1,43 +0,0 @@ -import io - -from haystack.preview.dataclasses import ByteStream - -import pytest - - -@pytest.mark.unit -def test_from_file_path(tmp_path, request): - test_bytes = "Hello, world!\n".encode() - test_path = tmp_path / request.node.name - with open(test_path, "wb") as fd: - assert fd.write(test_bytes) - - b = ByteStream.from_file_path(test_path) - assert b.data == test_bytes - assert b.mime_type == None - - b = ByteStream.from_file_path(test_path, mime_type="text/plain") - assert b.data == test_bytes - assert b.mime_type == "text/plain" - - -@pytest.mark.unit -def test_from_string(): - test_string = "Hello, world!" 
- b = ByteStream.from_string(test_string) - assert b.data.decode() == test_string - assert b.mime_type == None - - b = ByteStream.from_string(test_string, mime_type="text/plain") - assert b.data.decode() == test_string - assert b.mime_type == "text/plain" - - -@pytest.mark.unit -def test_to_file(tmp_path, request): - test_str = "Hello, world!\n" - test_path = tmp_path / request.node.name - - ByteStream(test_str.encode()).to_file(test_path) - with open(test_path, "rb") as fd: - assert fd.read().decode() == test_str diff --git a/test/preview/dataclasses/test_chat_message.py b/test/preview/dataclasses/test_chat_message.py deleted file mode 100644 index 1c0ec71cf3..0000000000 --- a/test/preview/dataclasses/test_chat_message.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -from transformers import AutoTokenizer - -from haystack.preview.dataclasses import ChatMessage, ChatRole - - -@pytest.mark.unit -def test_from_assistant_with_valid_content(): - content = "Hello, how can I assist you?" - message = ChatMessage.from_assistant(content) - assert message.content == content - assert message.role == ChatRole.ASSISTANT - - -@pytest.mark.unit -def test_from_user_with_valid_content(): - content = "I have a question." - message = ChatMessage.from_user(content) - assert message.content == content - assert message.role == ChatRole.USER - - -@pytest.mark.unit -def test_from_system_with_valid_content(): - content = "System message." - message = ChatMessage.from_system(content) - assert message.content == content - assert message.role == ChatRole.SYSTEM - - -@pytest.mark.unit -def test_with_empty_content(): - message = ChatMessage.from_user("") - assert message.content == "" - - -@pytest.mark.unit -def test_from_function_with_empty_name(): - content = "Function call" - message = ChatMessage.from_function(content, "") - assert message.content == content - assert message.name == "" - - -@pytest.mark.integration -def test_apply_chat_templating_on_chat_message(): - messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False) - assert tokenized_messages == "<|system|>\nYou are good assistant\n<|user|>\nI have a question\n" - - -@pytest.mark.integration -def test_apply_custom_chat_templating_on_chat_message(): - anthropic_template = ( - "{%- for message in messages %}" - "{%- if message.role == 'user' %}\n\nHuman: {{ message.content.strip() }}" - "{%- elif message.role == 'assistant' %}\n\nAssistant: {{ message.content.strip() }}" - "{%- elif message.role == 'function' %}{{ raise('anthropic does not support function calls.') }}" - "{%- elif message.role == 'system' and loop.index == 1 %}{{ message.content }}" - "{%- else %}{{ raise('Invalid message role: ' + message.role) }}" - "{%- endif %}" - "{%- endfor %}" - "\n\nAssistant:" - ) - messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] - # could be any tokenizer, let's use the one we already likely have in cache - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False) - assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:" diff --git a/test/preview/dataclasses/test_document.py b/test/preview/dataclasses/test_document.py deleted file mode 100644 
index 593a3f449e..0000000000 --- a/test/preview/dataclasses/test_document.py +++ /dev/null @@ -1,309 +0,0 @@ -from pathlib import Path - -import pandas as pd -import pytest - -from haystack.preview import Document -from haystack.preview.dataclasses.byte_stream import ByteStream - - -@pytest.mark.unit -@pytest.mark.parametrize( - "doc,doc_str", - [ - (Document(content="test text"), "content: 'test text'"), - ( - Document(dataframe=pd.DataFrame([["John", 25], ["Martha", 34]], columns=["name", "age"])), - "dataframe: (2, 2)", - ), - (Document(blob=ByteStream(b"hello, test string")), "blob: 18 bytes"), - ( - Document( - content="test text", - dataframe=pd.DataFrame([["John", 25], ["Martha", 34]], columns=["name", "age"]), - blob=ByteStream(b"hello, test string"), - ), - "content: 'test text', dataframe: (2, 2), blob: 18 bytes", - ), - ], -) -def test_document_str(doc, doc_str): - assert f"Document(id={doc.id}, {doc_str})" == str(doc) - - -@pytest.mark.unit -def test_init(): - doc = Document() - assert doc.id == "d4675c57fcfe114db0b95f1da46eea3c5d6f5729c17d01fb5251ae19830a3455" - assert doc.content == None - assert doc.dataframe == None - assert doc.blob == None - assert doc.meta == {} - assert doc.score == None - assert doc.embedding == None - - -@pytest.mark.unit -def test_init_with_wrong_parameters(): - with pytest.raises(TypeError): - Document(text="") - - -@pytest.mark.unit -def test_init_with_parameters(): - blob_data = b"some bytes" - doc = Document( - content="test text", - dataframe=pd.DataFrame([0]), - blob=ByteStream(data=blob_data, mime_type="text/markdown"), - meta={"text": "test text"}, - score=0.812, - embedding=[0.1, 0.2, 0.3], - ) - assert doc.id == "ec92455f3f4576d40031163c89b1b4210b34ea1426ee0ff68ebed86cb7ba13f8" - assert doc.content == "test text" - assert doc.dataframe is not None - assert doc.dataframe.equals(pd.DataFrame([0])) - assert doc.blob.data == blob_data - assert doc.blob.mime_type == "text/markdown" - assert doc.meta == {"text": "test text"} - assert doc.score == 0.812 - assert doc.embedding == [0.1, 0.2, 0.3] - - -@pytest.mark.unit -def test_init_with_legacy_fields(): - doc = Document( - content="test text", content_type="text", id_hash_keys=["content"], score=0.812, embedding=[0.1, 0.2, 0.3] # type: ignore - ) - assert doc.id == "18fc2c114825872321cf5009827ca162f54d3be50ab9e9ffa027824b6ec223af" - assert doc.content == "test text" - assert doc.dataframe == None - assert doc.blob == None - assert doc.meta == {} - assert doc.score == 0.812 - assert doc.embedding == [0.1, 0.2, 0.3] - - -@pytest.mark.unit -def test_init_with_legacy_field(): - doc = Document( - content="test text", - content_type="text", # type: ignore - id_hash_keys=["content"], # type: ignore - score=0.812, - embedding=[0.1, 0.2, 0.3], - meta={"date": "10-10-2023", "type": "article"}, - ) - assert doc.id == "a2c0321b34430cc675294611e55529fceb56140ca3202f1c59a43a8cecac1f43" - assert doc.content == "test text" - assert doc.dataframe == None - assert doc.meta == {"date": "10-10-2023", "type": "article"} - assert doc.score == 0.812 - assert doc.embedding == [0.1, 0.2, 0.3] - - -@pytest.mark.unit -def test_basic_equality_type_mismatch(): - doc = Document(content="test text") - assert doc != "test text" - - -@pytest.mark.unit -def test_basic_equality_id(): - doc1 = Document(content="test text") - doc2 = Document(content="test text") - - assert doc1 == doc2 - - doc1.id = "1234" - doc2.id = "5678" - - assert doc1 != doc2 - - -@pytest.mark.unit -def test_to_dict(): - doc = Document() - assert doc.to_dict() 
== { - "id": doc._create_id(), - "content": None, - "dataframe": None, - "blob": None, - "score": None, - "embedding": None, - } - - -@pytest.mark.unit -def test_to_dict_without_flattening(): - doc = Document() - assert doc.to_dict(flatten=False) == { - "id": doc._create_id(), - "content": None, - "dataframe": None, - "blob": None, - "meta": {}, - "score": None, - "embedding": None, - } - - -@pytest.mark.unit -def test_to_dict_with_custom_parameters(): - doc = Document( - content="test text", - dataframe=pd.DataFrame([10, 20, 30]), - blob=ByteStream(b"some bytes", mime_type="application/pdf"), - meta={"some": "values", "test": 10}, - score=0.99, - embedding=[10.0, 10.0], - ) - - assert doc.to_dict() == { - "id": doc.id, - "content": "test text", - "dataframe": pd.DataFrame([10, 20, 30]).to_json(), - "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"}, - "some": "values", - "test": 10, - "score": 0.99, - "embedding": [10.0, 10.0], - } - - -@pytest.mark.unit -def test_to_dict_with_custom_parameters_without_flattening(): - doc = Document( - content="test text", - dataframe=pd.DataFrame([10, 20, 30]), - blob=ByteStream(b"some bytes", mime_type="application/pdf"), - meta={"some": "values", "test": 10}, - score=0.99, - embedding=[10.0, 10.0], - ) - - assert doc.to_dict(flatten=False) == { - "id": doc.id, - "content": "test text", - "dataframe": pd.DataFrame([10, 20, 30]).to_json(), - "blob": {"data": list(b"some bytes"), "mime_type": "application/pdf"}, - "meta": {"some": "values", "test": 10}, - "score": 0.99, - "embedding": [10, 10], - } - - -@pytest.mark.unit -def test_from_dict(): - assert Document.from_dict({}) == Document() - - -@pytest.mark.unit -def from_from_dict_with_parameters(): - blob_data = b"some bytes" - assert Document.from_dict( - { - "content": "test text", - "dataframe": pd.DataFrame([0]).to_json(), - "blob": {"data": list(blob_data), "mime_type": "text/markdown"}, - "meta": {"text": "test text"}, - "score": 0.812, - "embedding": [0.1, 0.2, 0.3], - } - ) == Document( - content="test text", - dataframe=pd.DataFrame([0]), - blob=ByteStream(blob_data, mime_type="text/markdown"), - meta={"text": "test text"}, - score=0.812, - embedding=[0.1, 0.2, 0.3], - ) - - -@pytest.mark.unit -def test_from_dict_with_legacy_fields(): - assert Document.from_dict( - { - "content": "test text", - "content_type": "text", - "id_hash_keys": ["content"], - "score": 0.812, - "embedding": [0.1, 0.2, 0.3], - } - ) == Document( - content="test text", content_type="text", id_hash_keys=["content"], score=0.812, embedding=[0.1, 0.2, 0.3] # type: ignore - ) - - -def test_from_dict_with_legacy_field_and_flat_meta(): - assert Document.from_dict( - { - "content": "test text", - "content_type": "text", - "id_hash_keys": ["content"], - "score": 0.812, - "embedding": [0.1, 0.2, 0.3], - "date": "10-10-2023", - "type": "article", - } - ) == Document( - content="test text", - content_type="text", # type: ignore - id_hash_keys=["content"], # type: ignore - score=0.812, - embedding=[0.1, 0.2, 0.3], - meta={"date": "10-10-2023", "type": "article"}, - ) - - -@pytest.mark.unit -def test_from_dict_with_flat_meta(): - blob_data = b"some bytes" - assert Document.from_dict( - { - "content": "test text", - "dataframe": pd.DataFrame([0]).to_json(), - "blob": {"data": list(blob_data), "mime_type": "text/markdown"}, - "score": 0.812, - "embedding": [0.1, 0.2, 0.3], - "date": "10-10-2023", - "type": "article", - } - ) == Document( - content="test text", - dataframe=pd.DataFrame([0]), - blob=ByteStream(blob_data, 
mime_type="text/markdown"), - score=0.812, - embedding=[0.1, 0.2, 0.3], - meta={"date": "10-10-2023", "type": "article"}, - ) - - -@pytest.mark.unit -def test_from_dict_with_flat_and_non_flat_meta(): - with pytest.raises(ValueError, match="Pass either the 'meta' parameter or flattened metadata keys"): - Document.from_dict( - { - "content": "test text", - "dataframe": pd.DataFrame([0]).to_json(), - "blob": {"data": list(b"some bytes"), "mime_type": "text/markdown"}, - "score": 0.812, - "meta": {"test": 10}, - "embedding": [0.1, 0.2, 0.3], - "date": "10-10-2023", - "type": "article", - } - ) - - -@pytest.mark.unit -def test_content_type(): - assert Document(content="text").content_type == "text" - assert Document(dataframe=pd.DataFrame([0])).content_type == "table" - - with pytest.raises(ValueError): - _ = Document().content_type - - with pytest.raises(ValueError): - _ = Document(content="text", dataframe=pd.DataFrame([0])).content_type diff --git a/test/preview/dataclasses/test_streaming_chunk.py b/test/preview/dataclasses/test_streaming_chunk.py deleted file mode 100644 index 1a4c99ccc3..0000000000 --- a/test/preview/dataclasses/test_streaming_chunk.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest - -from haystack.preview.dataclasses import StreamingChunk - - -@pytest.mark.unit -def test_create_chunk_with_content_and_metadata(): - chunk = StreamingChunk(content="Test content", metadata={"key": "value"}) - - assert chunk.content == "Test content" - assert chunk.metadata == {"key": "value"} - - -@pytest.mark.unit -def test_create_chunk_with_only_content(): - chunk = StreamingChunk(content="Test content") - - assert chunk.content == "Test content" - assert chunk.metadata == {} - - -@pytest.mark.unit -def test_access_content(): - chunk = StreamingChunk(content="Test content", metadata={"key": "value"}) - assert chunk.content == "Test content" - - -@pytest.mark.unit -def test_create_chunk_with_empty_content(): - chunk = StreamingChunk(content="") - assert chunk.content == "" - assert chunk.metadata == {} diff --git a/test/preview/document_stores/test_in_memory.py b/test/preview/document_stores/test_in_memory.py deleted file mode 100644 index ce37754338..0000000000 --- a/test/preview/document_stores/test_in_memory.py +++ /dev/null @@ -1,399 +0,0 @@ -import logging -from unittest.mock import patch - -import pandas as pd -import pytest - -from haystack.preview import Document -from haystack.preview.document_stores import InMemoryDocumentStore, DocumentStoreError, DuplicatePolicy - - -from haystack.preview.testing.document_store import DocumentStoreBaseTests - - -class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 - """ - Test InMemoryDocumentStore's specific features - """ - - @pytest.fixture - def document_store(self) -> InMemoryDocumentStore: - return InMemoryDocumentStore() - - @pytest.mark.unit - def test_to_dict(self): - store = InMemoryDocumentStore() - data = store.to_dict() - assert data == { - "type": "haystack.preview.document_stores.in_memory.document_store.InMemoryDocumentStore", - "init_parameters": { - "bm25_tokenization_regex": r"(?u)\b\w\w+\b", - "bm25_algorithm": "BM25Okapi", - "bm25_parameters": {}, - "embedding_similarity_function": "dot_product", - }, - } - - @pytest.mark.unit - def test_to_dict_with_custom_init_parameters(self): - store = InMemoryDocumentStore( - bm25_tokenization_regex="custom_regex", - bm25_algorithm="BM25Plus", - bm25_parameters={"key": "value"}, - embedding_similarity_function="cosine", - ) - data = store.to_dict() - assert data 
== { - "type": "haystack.preview.document_stores.in_memory.document_store.InMemoryDocumentStore", - "init_parameters": { - "bm25_tokenization_regex": "custom_regex", - "bm25_algorithm": "BM25Plus", - "bm25_parameters": {"key": "value"}, - "embedding_similarity_function": "cosine", - }, - } - - @pytest.mark.unit - @patch("haystack.preview.document_stores.in_memory.document_store.re") - def test_from_dict(self, mock_regex): - data = { - "type": "haystack.preview.document_stores.in_memory.document_store.InMemoryDocumentStore", - "init_parameters": { - "bm25_tokenization_regex": "custom_regex", - "bm25_algorithm": "BM25Plus", - "bm25_parameters": {"key": "value"}, - }, - } - store = InMemoryDocumentStore.from_dict(data) - mock_regex.compile.assert_called_with("custom_regex") - assert store.tokenizer - assert store.bm25_algorithm.__name__ == "BM25Plus" - assert store.bm25_parameters == {"key": "value"} - - @pytest.mark.unit - def test_written_documents_count(self, document_store: InMemoryDocumentStore): - # FIXME Remove after the document store base tests have been rewritten - documents = [Document(content=f"Hello world #{i}") for i in range(10)] - docs_written = document_store.write_documents(documents[0:2]) - assert docs_written == 2 - assert document_store.filter_documents() == documents[0:2] - - docs_written = document_store.write_documents(documents, DuplicatePolicy.SKIP) - assert docs_written == len(documents) - 2 - assert document_store.filter_documents() == documents - - @pytest.mark.unit - def test_bm25_retrieval(self, document_store: InMemoryDocumentStore): - document_store = InMemoryDocumentStore() - # Tests if the bm25_retrieval method returns the correct document based on the input query. - docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="What languages?", top_k=1) - assert len(results) == 1 - assert results[0].content == "Haystack supports multiple languages" - - @pytest.mark.unit - def test_bm25_retrieval_with_empty_document_store(self, document_store: InMemoryDocumentStore, caplog): - caplog.set_level(logging.INFO) - # Tests if the bm25_retrieval method correctly returns an empty list when there are no documents in the DocumentStore. - results = document_store.bm25_retrieval(query="How to test this?", top_k=2) - assert len(results) == 0 - assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text - - @pytest.mark.unit - def test_bm25_retrieval_empty_query(self, document_store: InMemoryDocumentStore): - # Tests if the bm25_retrieval method returns a document when the query is an empty string. - docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] - document_store.write_documents(docs) - with pytest.raises(ValueError, match="Query should be a non-empty string"): - document_store.bm25_retrieval(query="", top_k=1) - - @pytest.mark.unit - def test_bm25_retrieval_with_different_top_k(self, document_store: InMemoryDocumentStore): - # Tests if the bm25_retrieval method correctly changes the number of returned documents - # based on the top_k parameter. 
- docs = [ - Document(content="Hello world"), - Document(content="Haystack supports multiple languages"), - Document(content="Python is a popular programming language"), - ] - document_store.write_documents(docs) - - # top_k = 2 - results = document_store.bm25_retrieval(query="languages", top_k=2) - assert len(results) == 2 - - # top_k = 3 - results = document_store.bm25_retrieval(query="languages", top_k=3) - assert len(results) == 3 - - # Test two queries and make sure the results are different - @pytest.mark.unit - def test_bm25_retrieval_with_two_queries(self, document_store: InMemoryDocumentStore): - # Tests if the bm25_retrieval method returns different documents for different queries. - docs = [ - Document(content="Javascript is a popular programming language"), - Document(content="Java is a popular programming language"), - Document(content="Python is a popular programming language"), - Document(content="Ruby is a popular programming language"), - Document(content="PHP is a popular programming language"), - ] - document_store.write_documents(docs) - - results = document_store.bm25_retrieval(query="Java", top_k=1) - assert results[0].content == "Java is a popular programming language" - - results = document_store.bm25_retrieval(query="Python", top_k=1) - assert results[0].content == "Python is a popular programming language" - - # Test a query, add a new document and make sure results are appropriately updated - @pytest.mark.unit - def test_bm25_retrieval_with_updated_docs(self, document_store: InMemoryDocumentStore): - # Tests if the bm25_retrieval method correctly updates the retrieved documents when new - # documents are added to the DocumentStore. - docs = [Document(content="Hello world")] - document_store.write_documents(docs) - - results = document_store.bm25_retrieval(query="Python", top_k=1) - assert len(results) == 1 - - document_store.write_documents([Document(content="Python is a popular programming language")]) - results = document_store.bm25_retrieval(query="Python", top_k=1) - assert len(results) == 1 - assert results[0].content == "Python is a popular programming language" - - document_store.write_documents([Document(content="Java is a popular programming language")]) - results = document_store.bm25_retrieval(query="Python", top_k=1) - assert len(results) == 1 - assert results[0].content == "Python is a popular programming language" - - @pytest.mark.unit - def test_bm25_retrieval_with_scale_score(self, document_store: InMemoryDocumentStore): - docs = [Document(content="Python programming"), Document(content="Java programming")] - document_store.write_documents(docs) - - results1 = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=True) - # Confirm that score is scaled between 0 and 1 - assert results1[0].score is not None - assert 0.0 <= results1[0].score <= 1.0 - - # Same query, different scale, scores differ when not scaled - results = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=False) - assert results[0].score != results1[0].score - - @pytest.mark.unit - def test_bm25_retrieval_with_table_content(self, document_store: InMemoryDocumentStore): - # Tests if the bm25_retrieval method correctly returns a dataframe when the content_type is table. 
- table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]}) - docs = [Document(dataframe=table_content), Document(content="Gardening"), Document(content="Bird watching")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Java", top_k=1) - assert len(results) == 1 - - df = results[0].dataframe - assert isinstance(df, pd.DataFrame) - assert df.equals(table_content) - - @pytest.mark.unit - def test_bm25_retrieval_with_text_and_table_content(self, document_store: InMemoryDocumentStore, caplog): - table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]}) - document = Document(content="Gardening", dataframe=table_content) - docs = [ - document, - Document(content="Python"), - Document(content="Bird Watching"), - Document(content="Gardening"), - Document(content="Java"), - ] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Gardening", top_k=2) - assert document.id in [d.id for d in results] - assert "both text and dataframe content" in caplog.text - results = document_store.bm25_retrieval(query="Python", top_k=2) - assert document.id not in [d.id for d in results] - - @pytest.mark.unit - def test_bm25_retrieval_default_filter_for_text_and_dataframes(self, document_store: InMemoryDocumentStore): - docs = [Document(), Document(content="Gardening"), Document(content="Bird watching")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="doesn't matter, top_k is 10", top_k=10) - assert len(results) == 2 - - @pytest.mark.unit - def test_bm25_retrieval_with_filters(self, document_store: InMemoryDocumentStore): - selected_document = Document(content="Gardening", meta={"selected": True}) - docs = [Document(), selected_document, Document(content="Bird watching")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True}) - assert len(results) == 1 - assert results[0].id == selected_document.id - - @pytest.mark.unit - def test_bm25_retrieval_with_filters_keeps_default_filters(self, document_store: InMemoryDocumentStore): - docs = [Document(meta={"selected": True}), Document(content="Gardening"), Document(content="Bird watching")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True}) - assert len(results) == 0 - - @pytest.mark.unit - def test_bm25_retrieval_with_filters_on_text_or_dataframe(self, document_store: InMemoryDocumentStore): - document = Document(dataframe=pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web"]})) - docs = [Document(), Document(content="Gardening"), Document(content="Bird watching"), document] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"content": None}) - assert len(results) == 1 - assert results[0].id == document.id - - @pytest.mark.unit - def test_bm25_retrieval_with_documents_with_mixed_content(self, document_store: InMemoryDocumentStore): - double_document = Document(content="Gardening", embedding=[1.0, 2.0, 3.0]) - docs = [Document(embedding=[1.0, 2.0, 3.0]), double_document, Document(content="Bird watching")] - document_store.write_documents(docs) - results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"embedding": {"$not": None}}) - assert len(results) == 1 - assert results[0].id == double_document.id - - 
@pytest.mark.unit - def test_embedding_retrieval(self): - docstore = InMemoryDocumentStore(embedding_similarity_function="cosine") - # Tests if the embedding retrieval method returns the correct document based on the input query embedding. - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]), - ] - docstore.write_documents(docs) - results = docstore.embedding_retrieval( - query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters={}, scale_score=False - ) - assert len(results) == 1 - assert results[0].content == "Haystack supports multiple languages" - - @pytest.mark.unit - def test_embedding_retrieval_invalid_query(self): - docstore = InMemoryDocumentStore() - with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"): - docstore.embedding_retrieval(query_embedding=[]) - with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"): - docstore.embedding_retrieval(query_embedding=["invalid", "list", "of", "strings"]) # type: ignore - - @pytest.mark.unit - def test_embedding_retrieval_no_embeddings(self, caplog): - caplog.set_level(logging.WARNING) - docstore = InMemoryDocumentStore() - docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] - docstore.write_documents(docs) - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1]) - assert len(results) == 0 - assert "No Documents found with embeddings. Returning empty list." in caplog.text - - @pytest.mark.unit - def test_embedding_retrieval_some_documents_wo_embeddings(self, caplog): - caplog.set_level(logging.INFO) - docstore = InMemoryDocumentStore() - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages"), - ] - docstore.write_documents(docs) - docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1]) - assert "Skipping some Documents that don't have an embedding." 
in caplog.text - - @pytest.mark.unit - def test_embedding_retrieval_documents_different_embedding_sizes(self): - docstore = InMemoryDocumentStore() - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0]), - ] - docstore.write_documents(docs) - - with pytest.raises(DocumentStoreError, match="The embedding size of all Documents should be the same."): - docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1]) - - @pytest.mark.unit - def test_embedding_retrieval_query_documents_different_embedding_sizes(self): - docstore = InMemoryDocumentStore() - docs = [Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])] - docstore.write_documents(docs) - - with pytest.raises( - DocumentStoreError, - match="The embedding size of the query should be the same as the embedding size of the Documents.", - ): - docstore.embedding_retrieval(query_embedding=[0.1, 0.1]) - - @pytest.mark.unit - def test_embedding_retrieval_with_different_top_k(self): - docstore = InMemoryDocumentStore() - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]), - Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]), - ] - docstore.write_documents(docs) - - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2) - assert len(results) == 2 - - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3) - assert len(results) == 3 - - @pytest.mark.unit - def test_embedding_retrieval_with_scale_score(self): - docstore = InMemoryDocumentStore() - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]), - Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]), - ] - docstore.write_documents(docs) - - results1 = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, scale_score=True) - # Confirm that score is scaled between 0 and 1 - assert results1[0].score is not None - assert 0.0 <= results1[0].score <= 1.0 - - # Same query, different scale, scores differ when not scaled - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, scale_score=False) - assert results[0].score != results1[0].score - - @pytest.mark.unit - def test_embedding_retrieval_return_embedding(self): - docstore = InMemoryDocumentStore(embedding_similarity_function="cosine") - docs = [ - Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]), - Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]), - ] - docstore.write_documents(docs) - - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, return_embedding=False) - assert results[0].embedding is None - - results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, return_embedding=True) - assert results[0].embedding == [1.0, 1.0, 1.0, 1.0] - - @pytest.mark.unit - def test_compute_cosine_similarity_scores(self): - docstore = InMemoryDocumentStore(embedding_similarity_function="cosine") - docs = [ - Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]), - Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]), - ] - - scores = docstore._compute_query_embedding_similarity_scores( - 
embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False - ) - assert scores == [0.5, 1.0] - - @pytest.mark.unit - def test_compute_dot_product_similarity_scores(self): - docstore = InMemoryDocumentStore(embedding_similarity_function="dot_product") - docs = [ - Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]), - Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]), - ] - - scores = docstore._compute_query_embedding_similarity_scores( - embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False - ) - assert scores == [0.1, 0.4] diff --git a/test/preview/test_files/audio/answer.wav b/test/preview/test_files/audio/answer.wav deleted file mode 100644 index 874eab01fe..0000000000 Binary files a/test/preview/test_files/audio/answer.wav and /dev/null differ diff --git a/test/preview/test_files/audio/the context for this answer is here.wav b/test/preview/test_files/audio/the context for this answer is here.wav deleted file mode 100644 index b6515a85db..0000000000 Binary files a/test/preview/test_files/audio/the context for this answer is here.wav and /dev/null differ diff --git a/test/preview/test_files/audio/this is the content of the document.wav b/test/preview/test_files/audio/this is the content of the document.wav deleted file mode 100644 index 37d651fa9d..0000000000 Binary files a/test/preview/test_files/audio/this is the content of the document.wav and /dev/null differ diff --git a/test/preview/test_files/docx/sample_docx.docx b/test/preview/test_files/docx/sample_docx.docx deleted file mode 100644 index 3a740ac968..0000000000 Binary files a/test/preview/test_files/docx/sample_docx.docx and /dev/null differ diff --git a/test/preview/test_files/html/what_is_haystack.html b/test/preview/test_files/html/what_is_haystack.html deleted file mode 100644 index 2d62b206c0..0000000000 --- a/test/preview/test_files/html/what_is_haystack.html +++ /dev/null @@ -1,1634 +0,0 @@ - - - - - - - - - - What is Haystack? | Haystack - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
What is Haystack?

-

Haystack is the open source Python framework by deepset for building custom apps with large language models (LLMs). It lets you quickly try out the latest models in natural language processing (NLP) while being flexible and easy to use. Our inspiring community of users and builders has helped shape Haystack into what it is today: a complete framework for building production-ready NLP apps.

-

Building with Haystack

-

Haystack offers comprehensive tooling for developing state-of-the-art NLP systems that use LLMs (such as GPT-4, Falcon, and similar) and Transformer models. With Haystack, you can effortlessly experiment with models hosted on platforms like Hugging Face, OpenAI, and Cohere, with models deployed on SageMaker, or with your own local models to find the perfect fit for your use case.

- [image: Model Providers]

Some examples of what you can build include:

-
  • Semantic search on a large collection of documents in any language
  • Generative question answering on a knowledge base containing mixed types of information: images, text, and tables
  • Natural language chatbots powered by cutting-edge generative models like GPT-4
  • An LLM-based Haystack Agent capable of resolving complex queries
  • Information extraction from documents to populate your database or build a knowledge graph

This is just a small subset of the kinds of systems that can be created in Haystack.

-

Functionality for all stages of an NLP project

-

A successful NLP project requires more than just the language models. As an end-to-end framework, Haystack assists you in building your system every step of the way, offering tooling for each stage of the NLP project life cycle:

- -

But that’s not all: metadata filtering, model distillation, or the prompt hub, whatever your NLP heart desires, you’re likely to find it in Haystack. And if not? We’ll build it together.

- [image: Rest API]

Building blocks

-

Haystack uses a few simple but effective concepts to help you build fully functional and customized end-to-end NLP systems.

-

Components

-

At the core of Haystack are its components—fundamental building blocks that can perform tasks like document retrieval, text generation, or summarization. A single component is already quite powerful. It can manage local language models or communicate with a hosted model through an API.

-

While Haystack offers a bunch of components you can use out of the box, it also lets you create your own custom components. Explore the collection of integrations that includes custom components developed by our community, which you can freely use.

-

You can chain components together to build pipelines, which are the foundation of the NLP app architecture in Haystack.

-
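To make the component idea above concrete, here is a minimal sketch of a custom component written against the preview `component` decorator that the deleted tests elsewhere in this diff exercise; the class name and its input and output names are invented purely for illustration.

```python
from haystack.preview import component


@component
class Uppercaser:
    """Toy component that upper-cases whatever text it receives."""

    @component.output_types(text=str)
    def run(self, text: str):
        # run() returns a dict whose keys match the output types declared above.
        return {"text": text.upper()}
```

Because a component is an ordinary Python object, you can also call it directly, for example `Uppercaser().run(text="hello")`, before wiring it into a pipeline.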

Pipelines

-

Pipelines are powerful structures made up of components, such as a Retriever and Reader, connected to infrastructure building blocks, such as a DocumentStore (for example, Elasticsearch or Weaviate), to form complex systems.

-

Haystack offers ready-made pipelines for most common tasks, such as question answering, document retrieval, or summarization. But it’s just as easy to design and create a custom pipeline for NLP scenarios that are way more complex than question answering.

-
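As a rough sketch of what assembling such a pipeline looks like (using the preview `Pipeline` API exercised by the deleted `test/preview/test_pipeline.py` below; the component class and connection names are illustrative only):

```python
from haystack.preview import Pipeline, component


@component
class Echo:
    """Illustrative pass-through component."""

    @component.output_types(value=str)
    def run(self, input_: str):
        return {"value": input_}


pipeline = Pipeline()
pipeline.add_component("first", Echo())
pipeline.add_component("second", Echo())
# Route the "value" output of "first" into the "input_" parameter of "second".
pipeline.connect("first.value", "second.input_")
# Pipelines can be serialized to and from YAML, as the deleted test_pipeline.py shows.
yaml_definition = pipeline.dumps()
```

The same pattern scales to real components such as a Retriever feeding a Reader; only the component classes and the connected sockets change.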

Agents

-

The Haystack Agent makes use of a large language model to resolve complex tasks. When initializing the Agent, you give it a set of tools, which can be pipeline components or whole pipelines. The Agent can use those tools iteratively to arrive at an answer. When given a query, the Agent determines which tools are useful to answer this query and calls them in a loop until it gets the answer. This way, it can achieve much more than extractive or generative question answering pipelines.

- [image: Agent Tools]

Who’s it for?

-

Haystack is for everyone looking to build natural language apps—NLP enthusiasts and newbies alike. You don’t need to understand how the models work under the hood. With Haystack’s modular and flexible components, pipelines, and agents, all you need is some basic knowledge of Python to dive right in.

-

Our community

-

At the heart of Haystack is the vibrant open source community that thrives on the diverse backgrounds and skill sets of its members. We value collaboration greatly and encourage our users to shape Haystack actively through GitHub contributions. Our Discord channel is a space where community members can connect, seek help, and learn from each other.

-

We also organize live online and in-person events, webinars, and office hours, which are an opportunity to learn and grow.


Enter the Haystack universe

- - - - - - - - - - - - - - - - - - - - - - diff --git a/test/preview/test_files/images/apple.jpg b/test/preview/test_files/images/apple.jpg deleted file mode 100644 index f9023fea2c..0000000000 Binary files a/test/preview/test_files/images/apple.jpg and /dev/null differ diff --git a/test/preview/test_files/images/haystack-logo.png b/test/preview/test_files/images/haystack-logo.png deleted file mode 100644 index bf001d00a5..0000000000 Binary files a/test/preview/test_files/images/haystack-logo.png and /dev/null differ diff --git a/test/preview/test_files/markdown/sample.md b/test/preview/test_files/markdown/sample.md deleted file mode 100644 index d39e32d44e..0000000000 --- a/test/preview/test_files/markdown/sample.md +++ /dev/null @@ -1,65 +0,0 @@ ---- -type: intro -date: 1.1.2023 ---- -```bash -pip install farm-haystack -``` -## What to build with Haystack - -- **Ask questions in natural language** and find granular answers in your own documents. -- Perform **semantic search** and retrieve documents according to meaning not keywords -- Use **off-the-shelf models** or **fine-tune** them to your own domain. -- Use **user feedback** to evaluate, benchmark and continuously improve your live models. -- Leverage existing **knowledge bases** and better handle the long tail of queries that **chatbots** receive. -- **Automate processes** by automatically applying a list of questions to new documents and using the extracted answers. - -![Logo](https://raw.githubusercontent.com/deepset-ai/haystack/main/docs/img/logo.png) - - -## Core Features - -- **Latest models**: Utilize all latest transformer based models (e.g. BERT, RoBERTa, MiniLM) for extractive QA, generative QA and document retrieval. -- **Modular**: Multiple choices to fit your tech stack and use case. Pick your favorite database, file converter or modeling framework. -- **Open**: 100% compatible with HuggingFace's model hub. Tight interfaces to other frameworks (e.g. Transformers, FARM, sentence-transformers) -- **Scalable**: Scale to millions of docs via retrievers, production-ready backends like Elasticsearch / FAISS and a fastAPI REST API -- **End-to-End**: All tooling in one place: file conversion, cleaning, splitting, training, eval, inference, labeling ... -- **Developer friendly**: Easy to debug, extend and modify. -- **Customizable**: Fine-tune models to your own domain or implement your custom DocumentStore. -- **Continuous Learning**: Collect new training data via user feedback in production & improve your models continuously - -| | | -|-|-| -| :ledger: [Docs](https://haystack.deepset.ai/overview/intro) | Usage, Guides, API documentation ...| -| :beginner: [Quick Demo](https://github.com/deepset-ai/haystack/#quick-demo) | Quickly see what Haystack offers | -| :floppy_disk: [Installation](https://github.com/deepset-ai/haystack/#installation) | How to install Haystack | -| :art: [Key Components](https://github.com/deepset-ai/haystack/#key-components) | Overview of core concepts | -| :mortar_board: [Tutorials](https://github.com/deepset-ai/haystack/#tutorials) | Jupyter/Colab Notebooks & Scripts | -| :eyes: [How to use Haystack](https://github.com/deepset-ai/haystack/#how-to-use-haystack) | Basic explanation of concepts, options and usage | -| :heart: [Contributing](https://github.com/deepset-ai/haystack/#heart-contributing) | We welcome all contributions! 
| -| :bar_chart: [Benchmarks](https://haystack.deepset.ai/benchmarks/v0.9.0) | Speed & Accuracy of Retriever, Readers and DocumentStores | -| :telescope: [Roadmap](https://haystack.deepset.ai/overview/roadmap) | Public roadmap of Haystack | -| :pray: [Slack](https://haystack.deepset.ai/community/join) | Join our community on Slack | -| :bird: [Twitter](https://twitter.com/deepset_ai) | Follow us on Twitter for news and updates | -| :newspaper: [Blog](https://medium.com/deepset-ai) | Read our articles on Medium | - - -## Quick Demo - -The quickest way to see what Haystack offers is to start a [Docker Compose](https://docs.docker.com/compose/) demo application: - -**1. Update/install Docker and Docker Compose, then launch Docker** - -``` - # apt-get update && apt-get install docker && apt-get install docker-compose - # service docker start -``` - -**2. Clone Haystack repository** - -``` - # git clone https://github.com/deepset-ai/haystack.git -``` - -### 2nd level headline for testing purposes -#### 3rd level headline for testing purposes diff --git a/test/preview/test_files/pdf/react_paper.pdf b/test/preview/test_files/pdf/react_paper.pdf deleted file mode 100644 index ec0d1289e5..0000000000 Binary files a/test/preview/test_files/pdf/react_paper.pdf and /dev/null differ diff --git a/test/preview/test_files/pdf/sample_pdf_1.pdf b/test/preview/test_files/pdf/sample_pdf_1.pdf deleted file mode 100644 index 87259b897f..0000000000 Binary files a/test/preview/test_files/pdf/sample_pdf_1.pdf and /dev/null differ diff --git a/test/preview/test_files/pdf/sample_pdf_2.pdf b/test/preview/test_files/pdf/sample_pdf_2.pdf deleted file mode 100644 index 6384246e89..0000000000 Binary files a/test/preview/test_files/pdf/sample_pdf_2.pdf and /dev/null differ diff --git a/test/preview/test_files/txt/doc_1.txt b/test/preview/test_files/txt/doc_1.txt deleted file mode 100644 index 4121890801..0000000000 --- a/test/preview/test_files/txt/doc_1.txt +++ /dev/null @@ -1,2 +0,0 @@ -Some text for testing. -Two lines in here. diff --git a/test/preview/test_files/txt/doc_2.txt b/test/preview/test_files/txt/doc_2.txt deleted file mode 100644 index 6f950eedcf..0000000000 --- a/test/preview/test_files/txt/doc_2.txt +++ /dev/null @@ -1,3 +0,0 @@ -This is a test line. -123 456 789 -987 654 321. diff --git a/test/preview/test_files/txt/doc_3.txt b/test/preview/test_files/txt/doc_3.txt deleted file mode 100644 index a7c4fd5e55..0000000000 --- a/test/preview/test_files/txt/doc_3.txt +++ /dev/null @@ -1,11 +0,0 @@ -That's yet another file! - -it contains - - - - -many - - -empty lines. 
diff --git a/test/preview/test_files/yaml/test_pipeline.yaml b/test/preview/test_files/yaml/test_pipeline.yaml deleted file mode 100644 index 0690b86f1a..0000000000 --- a/test/preview/test_files/yaml/test_pipeline.yaml +++ /dev/null @@ -1,14 +0,0 @@ -components: - Comp1: - init_parameters: - an_init_param: null - type: test_pipeline.TestComponent - Comp2: - init_parameters: - an_init_param: null - type: test_pipeline.TestComponent -connections: -- receiver: Comp2.input_ - sender: Comp1.value -max_loops_allowed: 99 -metadata: {} diff --git a/test/preview/test_pipeline.py b/test/preview/test_pipeline.py deleted file mode 100644 index be60e4be89..0000000000 --- a/test/preview/test_pipeline.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Optional - -import pytest - -from haystack.preview import Pipeline, component - - -@component -class TestComponent: - def __init__(self, an_init_param: Optional[str] = None): - pass - - @component.output_types(value=str) - def run(self, input_: str): - return {"value": input_} - - -@pytest.fixture -def pipeline(): - return Pipeline() - - -@pytest.mark.unit -def test_pipeline_dumps(pipeline, test_files_path): - pipeline.add_component("Comp1", TestComponent("Foo")) - pipeline.add_component("Comp2", TestComponent()) - pipeline.connect("Comp1.value", "Comp2.input_") - pipeline.max_loops_allowed = 99 - result = pipeline.dumps() - with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as f: - assert f.read() == result - - -@pytest.mark.unit -def test_pipeline_loads(test_files_path): - with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as f: - pipeline = Pipeline.loads(f.read()) - assert pipeline.max_loops_allowed == 99 - assert isinstance(pipeline.get_component("Comp1"), TestComponent) - assert isinstance(pipeline.get_component("Comp2"), TestComponent) - - -@pytest.mark.unit -def test_pipeline_dump(pipeline, test_files_path, tmp_path): - pipeline.add_component("Comp1", TestComponent("Foo")) - pipeline.add_component("Comp2", TestComponent()) - pipeline.connect("Comp1.value", "Comp2.input_") - pipeline.max_loops_allowed = 99 - with open(tmp_path / "out.yaml", "w") as f: - pipeline.dump(f) - # re-open and ensure it's the same data as the test file - with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as test_f, open(tmp_path / "out.yaml", "r") as f: - assert f.read() == test_f.read() - - -@pytest.mark.unit -def test_pipeline_load(test_files_path): - with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as f: - pipeline = Pipeline.load(f) - assert pipeline.max_loops_allowed == 99 - assert isinstance(pipeline.get_component("Comp1"), TestComponent) - assert isinstance(pipeline.get_component("Comp2"), TestComponent) diff --git a/test/preview/test_telemetry.py b/test/preview/test_telemetry.py deleted file mode 100644 index ba9105eee7..0000000000 --- a/test/preview/test_telemetry.py +++ /dev/null @@ -1,54 +0,0 @@ -import datetime -from unittest.mock import Mock, patch -import pytest - -from haystack.preview import Pipeline, component -from haystack.preview.telemetry._telemetry import pipeline_running - - -@pytest.mark.unit -@patch("haystack.preview.telemetry._telemetry.telemetry") -def test_pipeline_running(telemetry): - telemetry.send_event = Mock() - - @component - class Component: - def _get_telemetry_data(self): - return {"key": "values"} - - @component.output_types(value=int) - def run(self): - pass - - pipe = Pipeline() - pipe.add_component("component", Component()) - pipeline_running(pipe) - - # First run is always sent - 
telemetry.send_event.assert_called_once_with( - "Pipeline run (2.x)", - { - "pipeline_id": str(id(pipe)), - "runs": 1, - "components": {"test_telemetry.Component": [{"name": "component", "key": "values"}]}, - }, - ) - - # Running again before one minute has passed should not send another event - telemetry.send_event.reset_mock() - pipeline_running(pipe) - telemetry.send_event.assert_not_called() - - # Set the last telemetry sent time to pretend one minute has passed - pipe._last_telemetry_sent = pipe._last_telemetry_sent - datetime.timedelta(minutes=1) - - telemetry.send_event.reset_mock() - pipeline_running(pipe) - telemetry.send_event.assert_called_once_with( - "Pipeline run (2.x)", - { - "pipeline_id": str(id(pipe)), - "runs": 3, - "components": {"test_telemetry.Component": [{"name": "component", "key": "values"}]}, - }, - ) diff --git a/test/preview/testing/test_factory.py b/test/preview/testing/test_factory.py deleted file mode 100644 index 5a7fdc78a8..0000000000 --- a/test/preview/testing/test_factory.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from haystack.preview.dataclasses import Document -from haystack.preview.testing.factory import document_store_class -from haystack.preview.document_stores.decorator import document_store - - -@pytest.mark.unit -def test_document_store_class_default(): - MyStore = document_store_class("MyStore") - store = MyStore() - assert store.count_documents() == 0 - assert store.filter_documents() == [] - assert store.write_documents([]) is None - assert store.delete_documents([]) is None - assert store.to_dict() == {"type": "haystack.preview.testing.factory.MyStore", "init_parameters": {}} - - -@pytest.mark.unit -def test_document_store_from_dict(): - MyStore = document_store_class("MyStore") - - store = MyStore.from_dict({"type": "haystack.preview.testing.factory.MyStore", "init_parameters": {}}) - assert isinstance(store, MyStore) - - -@pytest.mark.unit -def test_document_store_class_is_registered(): - MyStore = document_store_class("MyStore") - assert document_store.registry["haystack.preview.testing.factory.MyStore"] == MyStore - - -@pytest.mark.unit -def test_document_store_class_with_documents(): - doc = Document(id="fake_id", content="This is a document") - MyStore = document_store_class("MyStore", documents=[doc]) - store = MyStore() - assert store.count_documents() == 1 - assert store.filter_documents() == [doc] - - -@pytest.mark.unit -def test_document_store_class_with_documents_count(): - MyStore = document_store_class("MyStore", documents_count=100) - store = MyStore() - assert store.count_documents() == 100 - assert store.filter_documents() == [] - - -@pytest.mark.unit -def test_document_store_class_with_documents_and_documents_count(): - doc = Document(id="fake_id", content="This is a document") - MyStore = document_store_class("MyStore", documents=[doc], documents_count=100) - store = MyStore() - assert store.count_documents() == 100 - assert store.filter_documents() == [doc] - - -@pytest.mark.unit -def test_document_store_class_with_bases(): - MyStore = document_store_class("MyStore", bases=(Exception,)) - store = MyStore() - assert isinstance(store, Exception) - - -@pytest.mark.unit -def test_document_store_class_with_extra_fields(): - MyStore = document_store_class("MyStore", extra_fields={"my_field": 10}) - store = MyStore() - assert store.my_field == 10 diff --git a/test/preview/utils/test_filters.py b/test/preview/utils/test_filters.py deleted file mode 100644 index 1b3baaf771..0000000000 --- 
a/test/preview/utils/test_filters.py +++ /dev/null @@ -1,725 +0,0 @@ -import pytest -import pandas as pd - -from haystack.preview import Document -from haystack.preview.errors import FilterError -from haystack.preview.utils.filters import convert, document_matches_filter - -document_matches_filter_data = [ - # == operator params - pytest.param( - {"field": "meta.name", "operator": "==", "value": "test"}, - Document(meta={"name": "test"}), - True, - id="== operator with equal values", - ), - pytest.param( - {"field": "meta.name", "operator": "==", "value": "test"}, - Document(meta={"name": "different value"}), - False, - id="== operator with different values", - ), - pytest.param( - {"field": "meta.name", "operator": "==", "value": "test"}, - Document(meta={"name": ["test"]}), - False, - id="== operator with different types values", - ), - pytest.param( - {"field": "dataframe", "operator": "==", "value": pd.DataFrame([1])}, - Document(dataframe=pd.DataFrame([1])), - True, - id="== operator with equal pandas.DataFrame values", - ), - pytest.param( - {"field": "dataframe", "operator": "==", "value": pd.DataFrame([1])}, - Document(dataframe=pd.DataFrame([10])), - False, - id="== operator with different pandas.DataFrame values", - ), - pytest.param( - {"field": "meta.name", "operator": "==", "value": "test"}, - Document(), - False, - id="== operator with missing Document value", - ), - pytest.param( - {"field": "meta.name", "operator": "==", "value": "test"}, - Document(meta={"name": None}), - False, - id="== operator with None Document value", - ), - pytest.param( - {"field": "meta.name", "operator": "==", "value": None}, - Document(meta={"name": "test"}), - False, - id="== operator with None filter value", - ), - # != operator params - pytest.param( - {"field": "meta.name", "operator": "!=", "value": "test"}, - Document(meta={"name": "test"}), - False, - id="!= operator with equal values", - ), - pytest.param( - {"field": "meta.name", "operator": "!=", "value": "test"}, - Document(meta={"name": "different value"}), - True, - id="!= operator with different values", - ), - pytest.param( - {"field": "meta.name", "operator": "!=", "value": "test"}, - Document(meta={"name": ["test"]}), - True, - id="!= operator with different types values", - ), - pytest.param( - {"field": "dataframe", "operator": "!=", "value": pd.DataFrame([1])}, - Document(dataframe=pd.DataFrame([1])), - False, - id="!= operator with equal pandas.DataFrame values", - ), - pytest.param( - {"field": "dataframe", "operator": "!=", "value": pd.DataFrame([1])}, - Document(dataframe=pd.DataFrame([10])), - True, - id="!= operator with different pandas.DataFrame values", - ), - pytest.param( - {"field": "meta.name", "operator": "!=", "value": "test"}, Document(), True, id="!= operator with missing value" - ), - pytest.param( - {"field": "meta.name", "operator": "!=", "value": "test"}, - Document(meta={"name": None}), - True, - id="!= operator with None Document value", - ), - pytest.param( - {"field": "meta.name", "operator": "!=", "value": None}, - Document(meta={"name": "test"}), - True, - id="!= operator with None filter value", - ), - # > operator params - pytest.param( - {"field": "meta.page", "operator": ">", "value": 10}, - Document(meta={"page": 10}), - False, - id="> operator with equal Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", "value": 10}, - Document(meta={"page": 11}), - True, - id="> operator with greater Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", 
"value": 10}, - Document(meta={"page": 9}), - False, - id="> operator with smaller Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - False, - id="> operator with equal ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1972-12-11T19:54:58"}), - True, - id="> operator with greater ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - False, - id="> operator with smaller ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", "value": 10}, - Document(), - False, - id="> operator with missing Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", "value": 10}, - Document(meta={"page": None}), - False, - id="> operator with None Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", "value": None}, - Document(meta={"page": 10}), - False, - id="> operator with None filter value", - ), - pytest.param( - {"field": "meta.page", "operator": ">", "value": None}, - Document(meta={"page": None}), - False, - id="> operator with None Document and filter value", - ), - # >= operator params - pytest.param( - {"field": "meta.page", "operator": ">=", "value": 10}, - Document(meta={"page": 10}), - True, - id=">= operator with equal Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": 10}, - Document(meta={"page": 11}), - True, - id=">= operator with greater Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": 10}, - Document(meta={"page": 9}), - False, - id=">= operator with smaller Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - True, - id=">= operator with equal ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1972-12-11T19:54:58"}), - True, - id=">= operator with greater ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": ">=", "value": "1972-12-11T19:54:58"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - False, - id=">= operator with smaller ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": 10}, - Document(), - False, - id=">= operator with missing Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": 10}, - Document(meta={"page": None}), - False, - id=">= operator with None Document value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": None}, - Document(meta={"page": 10}), - False, - id=">= operator with None filter value", - ), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": None}, - Document(meta={"page": None}), - False, - id=">= operator with None Document and filter value", - ), - # < operator params - pytest.param( - {"field": "meta.page", "operator": "<", "value": 10}, - Document(meta={"page": 10}), - False, - id="< operator with equal Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": 10}, - Document(meta={"page": 11}), - False, - id="< operator 
with greater Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": 10}, - Document(meta={"page": 9}), - True, - id="< operator with smaller Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - False, - id="< operator with equal ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1972-12-11T19:54:58"}), - False, - id="< operator with greater ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<", "value": "1972-12-11T19:54:58"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - True, - id="< operator with smaller ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": 10}, - Document(), - False, - id="< operator with missing Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": 10}, - Document(meta={"page": None}), - False, - id="< operator with None Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": None}, - Document(meta={"page": 10}), - False, - id="< operator with None filter value", - ), - pytest.param( - {"field": "meta.page", "operator": "<", "value": None}, - Document(meta={"page": None}), - False, - id="< operator with None Document and filter value", - ), - # <= operator params - pytest.param( - {"field": "meta.page", "operator": "<=", "value": 10}, - Document(meta={"page": 10}), - True, - id="<= operator with equal Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": 10}, - Document(meta={"page": 11}), - False, - id="<= operator with greater Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": 10}, - Document(meta={"page": 9}), - True, - id="<= operator with smaller Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - True, - id="<= operator with equal ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40"}, - Document(meta={"date": "1972-12-11T19:54:58"}), - False, - id="<= operator with greater ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.date", "operator": "<=", "value": "1972-12-11T19:54:58"}, - Document(meta={"date": "1969-07-21T20:17:40"}), - True, - id="<= operator with smaller ISO 8601 datetime Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": 10}, - Document(), - False, - id="<= operator with missing Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": 10}, - Document(meta={"page": None}), - False, - id="<= operator with None Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": None}, - Document(meta={"page": 10}), - False, - id="<= operator with None filter value", - ), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": None}, - Document(meta={"page": None}), - False, - id="<= operator with None Document and filter value", - ), - # in operator params - pytest.param( - {"field": "meta.page", "operator": "in", "value": [9, 10]}, - Document(meta={"page": 1}), - False, - id="in operator with filter value not containing Document value", - ), - 
pytest.param( - {"field": "meta.page", "operator": "in", "value": [9, 10]}, - Document(meta={"page": 10}), - True, - id="in operator with filter value containing Document value", - ), - # not in operator params - pytest.param( - {"field": "meta.page", "operator": "not in", "value": [9, 10]}, - Document(meta={"page": 1}), - True, - id="not in operator with filter value not containing Document value", - ), - pytest.param( - {"field": "meta.page", "operator": "not in", "value": [9, 10]}, - Document(meta={"page": 10}), - False, - id="not in operator with filter value containing Document value", - ), - # AND operator params - pytest.param( - { - "operator": "AND", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 10, "type": "article"}), - True, - id="AND operator with Document matching all conditions", - ), - pytest.param( - { - "operator": "AND", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 20, "type": "article"}), - False, - id="AND operator with Document matching a single condition", - ), - pytest.param( - { - "operator": "AND", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 11, "value": "blog post"}), - False, - id="AND operator with Document matching no condition", - ), - # OR operator params - pytest.param( - { - "operator": "OR", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 10, "type": "article"}), - True, - id="OR operator with Document matching all conditions", - ), - pytest.param( - { - "operator": "OR", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 20, "type": "article"}), - True, - id="OR operator with Document matching a single condition", - ), - pytest.param( - { - "operator": "OR", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 11, "value": "blog post"}), - False, - id="OR operator with Document matching no condition", - ), - # NOT operator params - pytest.param( - { - "operator": "NOT", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 10, "type": "article"}), - False, - id="NOT operator with Document matching all conditions", - ), - pytest.param( - { - "operator": "NOT", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 20, "type": "article"}), - True, - id="NOT operator with Document matching a single condition", - ), - pytest.param( - { - "operator": "NOT", - "conditions": [ - {"field": "meta.page", "operator": "==", "value": 10}, - {"field": "meta.type", "operator": "==", "value": "article"}, - ], - }, - Document(meta={"page": 11, "value": "blog post"}), - True, - id="NOT operator with Document matching no condition", - ), -] - - -@pytest.mark.parametrize("filter, document, 
expected_result", document_matches_filter_data) -def test_document_matches_filter(filter, document, expected_result): - assert document_matches_filter(filter, document) == expected_result - - -document_matches_filter_raises_error_data = [ - # > operator params - pytest.param({"field": "meta.page", "operator": ">", "value": "10"}, id="> operator with string filter value"), - pytest.param({"field": "meta.page", "operator": ">", "value": [10]}, id="> operator with list filter value"), - pytest.param( - {"field": "meta.page", "operator": ">", "value": pd.DataFrame([10])}, - id="> operator with pandas.DataFrame filter value", - ), - # >= operator params - pytest.param({"field": "meta.page", "operator": ">=", "value": "10"}, id=">= operator with string filter value"), - pytest.param({"field": "meta.page", "operator": ">=", "value": [10]}, id=">= operator with list filter value"), - pytest.param( - {"field": "meta.page", "operator": ">=", "value": pd.DataFrame([10])}, - id=">= operator with pandas.DataFrame filter value", - ), - # < operator params - pytest.param({"field": "meta.page", "operator": "<", "value": "10"}, id="< operator with string filter value"), - pytest.param({"field": "meta.page", "operator": "<", "value": [10]}, id="< operator with list filter value"), - pytest.param( - {"field": "meta.page", "operator": "<", "value": pd.DataFrame([10])}, - id="< operator with pandas.DataFrame filter value", - ), - # <= operator params - pytest.param({"field": "meta.page", "operator": "<=", "value": "10"}, id="<= operator with string filter value"), - pytest.param({"field": "meta.page", "operator": "<=", "value": [10]}, id="<= operator with list filter value"), - pytest.param( - {"field": "meta.page", "operator": "<=", "value": pd.DataFrame([10])}, - id="<= operator with pandas.DataFrame filter value", - ), - # in operator params - pytest.param({"field": "meta.page", "operator": "in", "value": 1}, id="in operator with non list filter value"), - # at some point we might want to support any iterable and this test should fail - pytest.param( - {"field": "meta.page", "operator": "in", "value": (10, 11)}, id="in operator with non list filter value" - ), - # not in operator params - pytest.param( - {"field": "meta.page", "operator": "not in", "value": 1}, id="not in operator with non list filter value" - ), - # at some point we might want to support any iterable and this test should fail - pytest.param( - {"field": "meta.page", "operator": "not in", "value": (10, 11)}, id="not in operator with non list filter value" - ), - # Malformed filters - pytest.param( - {"conditions": [{"field": "meta.name", "operator": "==", "value": "test"}]}, id="Missing root operator key" - ), - pytest.param({"operator": "AND"}, id="Missing root conditions key"), - pytest.param({"operator": "==", "value": "test"}, id="Missing condition field key"), - pytest.param({"field": "meta.name", "value": "test"}, id="Missing condition operator key"), - pytest.param({"field": "meta.name", "operator": "=="}, id="Missing condition value key"), -] - - -@pytest.mark.parametrize("filter", document_matches_filter_raises_error_data) -def test_document_matches_filter_raises_error(filter): - with pytest.raises(FilterError): - document = Document(meta={"page": 10}) - document_matches_filter(filter, document) - - -filters_data = [ - pytest.param( - { - "$and": { - "type": {"$eq": "article"}, - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": {"genre": {"$in": ["economy", "politics"]}, "publisher": {"$eq": 
"nytimes"}}, - } - }, - { - "operator": "AND", - "conditions": [ - {"field": "type", "operator": "==", "value": "article"}, - {"field": "date", "operator": ">=", "value": "2015-01-01"}, - {"field": "date", "operator": "<", "value": "2021-01-01"}, - {"field": "rating", "operator": ">=", "value": 3}, - { - "operator": "OR", - "conditions": [ - {"field": "genre", "operator": "in", "value": ["economy", "politics"]}, - {"field": "publisher", "operator": "==", "value": "nytimes"}, - ], - }, - ], - }, - id="All operators explicit", - ), - pytest.param( - { - "type": "article", - "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, - "rating": {"$gte": 3}, - "$or": {"genre": ["economy", "politics"], "publisher": "nytimes"}, - }, - { - "operator": "AND", - "conditions": [ - {"field": "type", "operator": "==", "value": "article"}, - {"field": "date", "operator": ">=", "value": "2015-01-01"}, - {"field": "date", "operator": "<", "value": "2021-01-01"}, - {"field": "rating", "operator": ">=", "value": 3}, - { - "operator": "OR", - "conditions": [ - {"field": "genre", "operator": "in", "value": ["economy", "politics"]}, - {"field": "publisher", "operator": "==", "value": "nytimes"}, - ], - }, - ], - }, - id="Root $and implicit", - ), - pytest.param( - { - "$or": [ - {"Type": "News Paper", "Date": {"$lt": "2019-01-01"}}, - {"Type": "Blog Post", "Date": {"$gte": "2019-01-01"}}, - ] - }, - { - "operator": "OR", - "conditions": [ - { - "operator": "AND", - "conditions": [ - {"field": "Type", "operator": "==", "value": "News Paper"}, - {"field": "Date", "operator": "<", "value": "2019-01-01"}, - ], - }, - { - "operator": "AND", - "conditions": [ - {"field": "Type", "operator": "==", "value": "Blog Post"}, - {"field": "Date", "operator": ">=", "value": "2019-01-01"}, - ], - }, - ], - }, - id="Root $or with list and multiple comparisons", - ), - pytest.param( - {"text": "A Foo Document 1"}, - {"operator": "AND", "conditions": [{"field": "text", "operator": "==", "value": "A Foo Document 1"}]}, - id="Implicit root $and and field $eq", - ), - pytest.param( - {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}, - { - "operator": "OR", - "conditions": [ - { - "operator": "OR", - "conditions": [ - {"field": "name", "operator": "==", "value": "name_0"}, - {"field": "name", "operator": "==", "value": "name_1"}, - ], - }, - {"field": "number", "operator": "<", "value": 1.0}, - ], - }, - id="Root $or with dict and field $or with list", - ), - pytest.param( - {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]}, - { - "operator": "AND", - "conditions": [ - {"field": "number", "operator": "<=", "value": 2}, - {"field": "number", "operator": ">=", "value": 0}, - {"field": "name", "operator": "in", "value": ["name_0", "name_1"]}, - ], - }, - id="Implicit $and and field $in", - ), - pytest.param( - {"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}, - { - "operator": "AND", - "conditions": [ - {"field": "number", "operator": "<=", "value": 2}, - {"field": "number", "operator": ">=", "value": 0}, - ], - }, - id="Implicit root $and and field $and with list", - ), - pytest.param( - { - "$not": { - "number": {"$lt": 1.0}, - "$and": {"name": {"$in": ["name_0", "name_1"]}, "$not": {"chapter": {"$eq": "intro"}}}, - } - }, - { - "operator": "NOT", - "conditions": [ - {"field": "number", "operator": "<", "value": 1.0}, - { - "operator": "AND", - "conditions": [ - {"field": "name", "operator": "in", "value": ["name_0", "name_1"]}, - {"operator": "NOT", "conditions": [{"field": 
"chapter", "operator": "==", "value": "intro"}]}, - ], - }, - ], - }, - id="Root explicit $not", - ), - pytest.param( - {"page": {"$not": 102}}, - {"operator": "NOT", "conditions": [{"field": "page", "operator": "==", "value": 102}]}, - id="Explicit $not with implicit $eq", - ), -] - - -@pytest.mark.parametrize("old_style, new_style", filters_data) -def test_convert(old_style, new_style): - assert convert(old_style) == new_style - - -def test_convert_with_incorrect_input_type(): - with pytest.raises(ValueError): - convert("some string") - - -def test_convert_with_incorrect_filter_nesting(): - with pytest.raises(FilterError): - convert({"number": {"page": "100"}}) - - with pytest.raises(FilterError): - convert({"number": {"page": {"chapter": "intro"}}}) diff --git a/test/prompt/conftest.py b/test/prompt/conftest.py index 9d38e6d0dd..12b850207f 100644 --- a/test/prompt/conftest.py +++ b/test/prompt/conftest.py @@ -23,12 +23,12 @@ def prompt_model(request, haystack_azure_conf): api_key = os.environ.get("OPENAI_API_KEY", "KEY_NOT_FOUND") if api_key is None or api_key == "": api_key = "KEY_NOT_FOUND" - return PromptModel("text-davinci-003", api_key=api_key) + return PromptModel("gpt-3.5-turbo-instruct", api_key=api_key) elif request.param == "azure": api_key = os.environ.get("AZURE_OPENAI_API_KEY", "KEY_NOT_FOUND") if api_key is None or api_key == "": api_key = "KEY_NOT_FOUND" - return PromptModel("text-davinci-003", api_key=api_key, model_kwargs=haystack_azure_conf) + return PromptModel("gpt-3.5-turbo-instruct", api_key=api_key, model_kwargs=haystack_azure_conf) else: return PromptModel("google/flan-t5-base", devices=["cpu"]) diff --git a/test/prompt/invocation_layer/test_amazon_bedrock.py b/test/prompt/invocation_layer/test_amazon_bedrock.py index a605c68ccc..dace56b30e 100644 --- a/test/prompt/invocation_layer/test_amazon_bedrock.py +++ b/test/prompt/invocation_layer/test_amazon_bedrock.py @@ -14,6 +14,7 @@ CohereCommandAdapter, AmazonTitanAdapter, MetaLlama2ChatAdapter, + MistralAIAdapter, ) with LazyImport() as boto3_import: @@ -305,6 +306,13 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("meta.llama3-8b-instruct-v1:0", MetaLlama2ChatAdapter), + ("meta.llama3-70b-instruct-v1:0", MetaLlama2ChatAdapter), + ("meta.llama3-130b-instruct-v5:9", MetaLlama2ChatAdapter), # artificial + ("mistral.mistral-7b-instruct-v0:2", MistralAIAdapter), + ("mistral.mixtral-8x7b-instruct-v0:1", MistralAIAdapter), + ("mistral.mistral-large-2402-v1:0", MistralAIAdapter), + ("mistral.mistral-medium-v8:0", MistralAIAdapter), # artificial ("unknown_model", None), ], ) @@ -317,9 +325,183 @@ def test_get_model_adapter(model_name_or_path: str, expected_model_adapter: Opti class TestAnthropicClaudeAdapter: + def test_default_init(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100) + assert adapter.use_messages_api is True + + def test_use_messages_api_false(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=100) + assert adapter.use_messages_api is False + + +class TestAnthropicClaudeAdapterMessagesAPI: def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" 
+ expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 99, + "anthropic_version": "bedrock-2023-05-31", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + stop_sequences=["CUSTOM_STOP"], + system="system prompt", + anthropic_version="custom_version", + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "system": "system prompt", + "anthropic_version": "custom_version", + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "system prompt", + "anthropic_version": "custom_version", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_tokens": 49, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "system": "system prompt", + "anthropic_version": "custom_version", + }, + max_length=99, + ) + prompt = "Hello, how are you?" 
+ expected_body = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "system": "new system prompt", + "anthropic_version": "new_custom_version", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens=50, + system="new system prompt", + anthropic_version="new_custom_version", + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"content": [{"text": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"delta": {"text": " This"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " is"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " a"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " single"}}'}}, + {"chunk": {"bytes": b'{"delta": {"text": " response."}}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"delta": {"text": " This"}}), + call(" is", event_data={"delta": {"text": " is"}}), + call(" a", event_data={"delta": {"text": " a"}}), + call(" single", event_data={"delta": {"text": " single"}}), + call(" response.", event_data={"delta": {"text": " response."}}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() + + +class TestAnthropicClaudeAdapterNoMessagesAPI: + def test_prepare_body_with_default_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) + prompt = "Hello, how are you?" expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", "max_tokens_to_sample": 99, @@ -331,7 +513,7 @@ def test_prepare_body_with_default_params(self) -> None: assert body == expected_body def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + layer = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) prompt = "Hello, how are you?" 
expected_body = { "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", @@ -357,6 +539,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None: def test_prepare_body_with_model_kwargs(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.7, "top_p": 0.8, "top_k": 5, @@ -383,6 +566,7 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = AnthropicClaudeAdapter( model_kwargs={ + "use_messages_api": False, "temperature": 0.6, "top_p": 0.7, "top_k": 4, @@ -406,13 +590,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non assert body == expected_body def test_get_responses(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) response_body = {"completion": "\n\t This is a single response."} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -431,7 +615,7 @@ def test_get_stream_responses(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = ["This is a single response."] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses @@ -453,7 +637,7 @@ def test_get_stream_responses_empty(self) -> None: stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received - adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + adapter = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=99) expected_responses = [""] assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses @@ -1031,3 +1215,106 @@ def test_get_stream_responses_empty(self) -> None: assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() + + +class TestMistralAIAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = MistralAIAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_tokens": 99} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = MistralAIAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" 
+ expected_body = {"prompt": "Hello, how are you?", "max_tokens": 50, "temperature": 0.7, "top_p": 0.8} + + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_tokens=50, unknown_arg="unknown_value") + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = MistralAIAdapter( + model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_tokens": 50, "unknown_arg": "unknown_value"}, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_tokens": 50, "temperature": 0.7, "top_p": 0.8} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = MistralAIAdapter( + model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_tokens": 49}, max_length=99 + ) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "Hello, how are you?", + "max_tokens": 50, + "temperature": 0.7, + "top_p": 0.7, + "top_k": 4, + } + + body = layer.prepare_body(prompt, temperature=0.7, max_tokens=50) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = MistralAIAdapter(model_kwargs={}, max_length=99) + response_body = {"outputs": [{"text": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = MistralAIAdapter(model_kwargs={}, max_length=99) + response_body = {"outputs": [{"text": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"outputs": [{"text": " This"}]}'}}, + {"chunk": {"bytes": b'{"outputs": [{"text": " is"}]}'}}, + {"chunk": {"bytes": b'{"outputs": [{"text": " a"}]}'}}, + {"chunk": {"bytes": b'{"outputs": [{"text": " single"}]}'}}, + {"chunk": {"bytes": b'{"outputs": [{"text": " response."}]}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = MistralAIAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"outputs": [{"text": " This"}]}), + call(" is", event_data={"outputs": [{"text": " is"}]}), + call(" a", event_data={"outputs": [{"text": " a"}]}), + call(" single", event_data={"outputs": [{"text": " single"}]}), + call(" response.", event_data={"outputs": [{"text": " response."}]}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = MistralAIAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py index c1b816d493..d84e3dc384 100644 --- 
a/test/prompt/invocation_layer/test_chatgpt.py +++ b/test/prompt/invocation_layer/test_chatgpt.py @@ -48,12 +48,12 @@ def test_chatgpt_token_limit_warning_single_prompt(mock_openai_tokenizer, caplog model_name_or_path="gpt-3.5-turbo", api_key="fake_api_key", api_base="https://fake_api_base.com", - max_length=4090, + max_length=16379, ) with caplog.at_level(logging.WARNING): _ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.") assert "The prompt has been truncated from" in caplog.text - assert "and answer length (4090 tokens) fit within the max token limit (4096 tokens)." in caplog.text + assert "and answer length (16379 tokens) fit within the max token limit (16385 tokens)." in caplog.text @pytest.mark.unit @@ -70,7 +70,7 @@ def test_chatgpt_token_limit_warning_with_messages(mock_openai_tokenizer, caplog model_name_or_path="gpt-3.5-turbo", api_key="fake_api_key", api_base="https://fake_api_base.com", - max_length=4060, + max_length=16379, ) with pytest.raises(ValueError): _ = invocation_layer._ensure_token_limit(prompt=messages) diff --git a/test/prompt/invocation_layer/test_hugging_face.py b/test/prompt/invocation_layer/test_hugging_face.py index 3b56cab785..fcdf399d93 100644 --- a/test/prompt/invocation_layer/test_hugging_face.py +++ b/test/prompt/invocation_layer/test_hugging_face.py @@ -59,6 +59,9 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task): assert kwargs["task"] == "text2text-generation" assert kwargs["model"] == "google/flan-t5-base" + # use_auth_token is no longer used in HuggingFace pipelines and should never be passed + assert "use_auth_token" not in kwargs + # no matter what kwargs we pass or don't pass, there are always 14 predefined kwargs passed to the pipeline assert len(kwargs) == 14 @@ -76,7 +79,7 @@ def test_constructor_with_model_name_only(mock_pipeline, mock_get_task): "pipeline_class", "use_fast", "revision", - "use_auth_token", + "token", "trust_remote_code", ] diff --git a/test/prompt/invocation_layer/test_openai.py b/test/prompt/invocation_layer/test_openai.py index 5ae3458788..9438e0fb38 100644 --- a/test/prompt/invocation_layer/test_openai.py +++ b/test/prompt/invocation_layer/test_openai.py @@ -41,22 +41,25 @@ def test_custom_api_base(mock_open_ai_request, load_openai_tokenizer): @pytest.mark.unit def test_openai_token_limit_warning(mock_openai_tokenizer, caplog): invocation_layer = OpenAIInvocationLayer( - model_name_or_path="text-ada-001", api_key="fake_api_key", api_base="https://fake_api_base.com", max_length=2045 + model_name_or_path="babbage-002", api_key="fake_api_key", api_base="https://fake_api_base.com", max_length=16385 ) with caplog.at_level(logging.WARNING): _ = invocation_layer._ensure_token_limit(prompt="This is a test for a mock openai tokenizer.") assert "The prompt has been truncated from" in caplog.text - assert "and answer length (2045 tokens) fit within the max token limit (2049 tokens)." in caplog.text + assert "and answer length (16385 tokens) fit within the max token limit (16384 tokens)." 
in caplog.text @pytest.mark.unit @pytest.mark.parametrize( "model_name,max_tokens_limit", [ - ("text-davinci-003", 4097), - ("gpt-3.5-turbo", 4096), - ("gpt-3.5-turbo-16k", 16384), + ("gpt-3.5-turbo-instruct", 4096), + ("gpt-3.5-turbo", 16385), + ("gpt-3.5-turbo-16k", 16385), ("gpt-4-32k", 32768), + ("gpt-4-1106", 128000), + ("gpt-4-turbo-preview", 128000), + ("gpt-4-0125-preview", 128000), ("gpt-4", 8192), ], ) @@ -76,10 +79,13 @@ def test_openai_token_limit_warning_not_triggered(caplog, mock_openai_tokenizer, @pytest.mark.parametrize( "model_name,max_tokens_limit", [ - ("text-davinci-003", 4097), - ("gpt-3.5-turbo", 4096), - ("gpt-3.5-turbo-16k", 16384), + ("gpt-3.5-turbo-instruct", 4096), + ("gpt-3.5-turbo", 16385), + ("gpt-3.5-turbo-16k", 16385), ("gpt-4-32k", 32768), + ("gpt-4-1106", 128000), + ("gpt-4-turbo-preview", 128000), + ("gpt-4-0125-preview", 128000), ("gpt-4", 8192), ], ) diff --git a/test/prompt/test_prompt_model.py b/test/prompt/test_prompt_model.py index 9e7c2b0f51..5406172f63 100644 --- a/test/prompt/test_prompt_model.py +++ b/test/prompt/test_prompt_model.py @@ -4,7 +4,7 @@ import pytest from haystack.nodes.prompt.prompt_model import PromptModel -from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer, HFLocalInvocationLayer +from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer from .conftest import create_mock_layer_that_supports @@ -39,6 +39,16 @@ def test_constructor_with_no_supported_model(): PromptModel("some-random-model") +@pytest.mark.unit +def test_constructor_with_invocation_layer_class_string(mock_auto_tokenizer): + model = PromptModel( + invocation_layer_class="haystack.nodes.prompt.invocation_layer.CohereInvocationLayer", api_key="fake_api_key" + ) + from haystack.nodes.prompt.invocation_layer import CohereInvocationLayer + + assert isinstance(model.model_invocation_layer, CohereInvocationLayer) + + @pytest.mark.asyncio async def test_ainvoke(): def async_return(result): diff --git a/test/prompt/test_prompt_node.py b/test/prompt/test_prompt_node.py index 972a04be18..aba2ed8833 100644 --- a/test/prompt/test_prompt_node.py +++ b/test/prompt/test_prompt_node.py @@ -216,7 +216,7 @@ def test_azure_vs_open_ai_invocation_layer_selection(): node = PromptNode("gpt-4", api_key="some_key", model_kwargs=azure_model_kwargs) assert isinstance(node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer) - node = PromptNode("text-davinci-003", api_key="some_key", model_kwargs=azure_model_kwargs) + node = PromptNode("gpt-3.5-turbo-instruct", api_key="some_key", model_kwargs=azure_model_kwargs) assert isinstance(node.prompt_model.model_invocation_layer, AzureOpenAIInvocationLayer) node = PromptNode("gpt-4", api_key="some_key") @@ -224,7 +224,7 @@ def test_azure_vs_open_ai_invocation_layer_selection(): node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer ) - node = PromptNode("text-davinci-003", api_key="some_key") + node = PromptNode("gpt-3.5-turbo-instruct", api_key="some_key") assert isinstance(node.prompt_model.model_invocation_layer, OpenAIInvocationLayer) and not isinstance( node.prompt_model.model_invocation_layer, AzureChatGPTInvocationLayer ) @@ -432,7 +432,6 @@ def test_pipeline_with_prompt_template_at_query_time(prompt_model): ) -@pytest.mark.skip @pytest.mark.integration def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path): # TODO: This can be a Shaper unit test? 
@@ -444,7 +443,7 @@ def test_pipeline_with_prompt_template_and_nested_shaper_yaml(tmp_path): - name: template_with_nested_shaper type: PromptTemplate params: - prompt: "Given the context please answer the question. Context: {{documents}}; Question: {{query}}; Answer: " + prompt: "Given the context please answer the question. Context: {documents}; Question: {query}; Answer: " output_parser: type: AnswerParser - name: p1 @@ -850,7 +849,7 @@ def test_complex_pipeline_with_all_features(tmp_path, haystack_openai_config): - name: pmodel_openai type: PromptModel params: - model_name_or_path: text-davinci-003 + model_name_or_path: gpt-3.5-turbo-instruct model_kwargs: temperature: 0.9 max_tokens: 64 @@ -1052,7 +1051,7 @@ def test_content_moderation_gpt_3(): OpenAIInvocationLayer. """ prompt_node = PromptNode( - model_name_or_path="text-davinci-003", api_key="key", model_kwargs={"moderate_content": True} + model_name_or_path="gpt-3.5-turbo-instruct", api_key="key", model_kwargs={"moderate_content": True} ) with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch( "haystack.nodes.prompt.invocation_layer.open_ai.openai_request" @@ -1202,3 +1201,30 @@ async def test_aprompt(mock_model): mock_model.return_value.ainvoke = AsyncMock() await node._aprompt(PromptTemplate("test template")) mock_model.return_value.ainvoke.assert_awaited_once() + + +@pytest.mark.unit +@patch("haystack.nodes.prompt.prompt_node.PromptModel") +def test_prompt_no_truncation(mock_model, caplog): + node = PromptNode(truncate=False) + mock_model.return_value.invoke = MagicMock() + prompt = "Repeating text" * 200 + "Docs: Berlin is an amazing city.; Answer:" + with caplog.at_level(logging.DEBUG): + _ = node.prompt(PromptTemplate(prompt)) + assert prompt in caplog.text + + +@pytest.mark.unit +def test_run_with_empty_inputs(): + mock_model = MagicMock(spec=PromptModel) + mock_model.invoke.return_value = ["mock answer"] + node = PromptNode(mock_model, default_prompt_template="question-answering") + result, _ = node.run(query="", documents=[]) + + # validate output variable present + assert "answers" in result + assert len(result["answers"]) == 1 + + # and that so-called invocation context contains the right keys + assert "invocation_context" in result + assert all(item in result["invocation_context"] for item in ["query", "documents", "answers", "prompts"]) diff --git a/test/utils/test_openai_utils.py b/test/utils/test_openai_utils.py index 7126542f0c..7b67be73be 100644 --- a/test/utils/test_openai_utils.py +++ b/test/utils/test_openai_utils.py @@ -14,45 +14,43 @@ @pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt_default(): - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name="not-recognized-name") - assert tokenizer_name == "gpt2" - assert max_tokens_limit == 2049 - - -@pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt_davinci(): - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name="text-davinci-003") - assert tokenizer_name == "p50k_base" - assert max_tokens_limit == 4097 - - -@pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt3_5_azure(): - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name="gpt-35-turbo") - assert tokenizer_name == "cl100k_base" - assert max_tokens_limit == 4096 - - -@pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt3_5(): - tokenizer_name, 
max_tokens_limit = _openai_text_completion_tokenization_details(model_name="gpt-3.5-turbo") - assert tokenizer_name == "cl100k_base" - assert max_tokens_limit == 4096 - - -@pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt_4(): - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name="gpt-4") - assert tokenizer_name == "cl100k_base" - assert max_tokens_limit == 8192 - - -@pytest.mark.unit -def test_openai_text_completion_tokenization_details_gpt_4_32k(): - tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name="gpt-4-32k") - assert tokenizer_name == "cl100k_base" - assert max_tokens_limit == 32768 +@pytest.mark.parametrize( + "model_name,tok_name,max_tok_limit", + [ + # Default + ("not-recognized-name", "cl100k_base", 4096), + # GPT-3.5 + ("gpt-3.5-turbo-0125", "cl100k_base", 16385), + ("gpt-3.5-turbo-instruct", "cl100k_base", 4096), + ("gpt-3.5-turbo-0613", "cl100k_base", 4096), + ("gpt-3.5-turbo", "cl100k_base", 16385), + ("gpt-3.5-turbo-1106", "cl100k_base", 16385), + ("gpt-3.5-turbo-16k", "cl100k_base", 16385), + ("gpt-3.5-turbo-16k-0613", "cl100k_base", 16385), + # GPT 4 + ("gpt-4-0125-preview", "cl100k_base", 128000), + ("gpt-4-turbo-preview", "cl100k_base", 128000), + ("gpt-4-1106-preview", "cl100k_base", 128000), + ("gpt-4-vision-preview", "cl100k_base", 128000), + ("gpt-4-1106-vision-preview", "cl100k_base", 128000), + ("gpt-4", "cl100k_base", 8192), + ("gpt-4-0613", "cl100k_base", 8192), + ("gpt-4-32k", "cl100k_base", 32768), + ("gpt-4-32k-0613", "cl100k_base", 32768), + ("gpt-4-1106", "cl100k_base", 128000), + # GPT-35 Azure + ("gpt-35-turbo-instruct", "cl100k_base", 4096), + ("gpt-35-turbo", "cl100k_base", 16385), + ("gpt-35-turbo-16k", "cl100k_base", 16385), + # davinci and babbage + ("davinci-002", "cl100k_base", 16384), + ("babbage-002", "cl100k_base", 16384), + ], +) +def test_openai_text_completion_tokenization(model_name, tok_name, max_tok_limit): + tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=model_name) + assert tokenizer_name == tok_name + assert max_tokens_limit == max_tok_limit @pytest.mark.unit diff --git a/test/utils/test_torch_utils.py b/test/utils/test_torch_utils.py new file mode 100644 index 0000000000..824edd3102 --- /dev/null +++ b/test/utils/test_torch_utils.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from haystack.utils.torch_utils import resolve_torch_dtype + + +def test_extract_torch_dtype() -> None: + torch_dtype = resolve_torch_dtype(**{"torch_dtype": torch.float16}) + assert torch_dtype == torch.float16 + + +def test_extract_torch_dtype_none() -> None: + torch_dtype = resolve_torch_dtype(**{}) + assert torch_dtype is None + + +def test_extract_torch_dtype_str() -> None: + torch_dtype = resolve_torch_dtype(**{"torch_dtype": "torch.float16"}) + assert torch_dtype == torch.float16 + + +def test_extract_torch_dtype_auto() -> None: + torch_dtype = resolve_torch_dtype(**{"torch_dtype": "auto"}) + assert torch_dtype == "auto" + + +def test_extract_torch_dtype_invalid() -> None: + with pytest.raises(ValueError): + _ = resolve_torch_dtype(**{"torch_dtype": "random string"})
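
Editor's note (appended after the patch, not part of it): the new test/utils/test_torch_utils.py file added at the end of this diff fully pins down the expected behaviour of haystack.utils.torch_utils.resolve_torch_dtype (dtype object passthrough, missing kwarg -> None, "torch.float16" string parsing, "auto" passthrough, ValueError on anything else). As a reading aid, a minimal sketch that would satisfy those five tests is shown below; it is an assumption about the helper's shape, not the actual implementation from the repository.

import torch


def resolve_torch_dtype(**kwargs):
    """Resolve an optional `torch_dtype` kwarg into a torch.dtype, the string "auto", or None."""
    torch_dtype = kwargs.get("torch_dtype")
    if torch_dtype is None:
        return None  # kwarg absent -> None (test_extract_torch_dtype_none)
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype  # already a dtype, e.g. torch.float16 (test_extract_torch_dtype)
    if isinstance(torch_dtype, str):
        if torch_dtype == "auto":
            return "auto"  # passed through verbatim (test_extract_torch_dtype_auto)
        # accept "torch.float16" as well as "float16" (test_extract_torch_dtype_str)
        resolved = getattr(torch, torch_dtype.split("torch.")[-1], None)
        if isinstance(resolved, torch.dtype):
            return resolved
    # anything unrecognised, e.g. "random string", is rejected (test_extract_torch_dtype_invalid)
    raise ValueError(f"Invalid torch_dtype value: {torch_dtype!r}")

Accepting both dtype objects and their string spellings mirrors how torch_dtype is typically forwarded to Hugging Face pipelines, which is presumably why the tests cover both forms.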