diff --git a/.evergreen/combine-coverage.sh b/.evergreen/combine-coverage.sh
index c31f755bd9..36266c1842 100755
--- a/.evergreen/combine-coverage.sh
+++ b/.evergreen/combine-coverage.sh
@@ -3,8 +3,7 @@
 # Coverage combine merges (and removes) all the coverage files and
 # generates a new .coverage file in the current directory.
 
-set -o xtrace # Write all commands first to stderr
-set -o errexit # Exit the script with error if any of the commands fail
+set -eu
 
 . .evergreen/utils.sh
diff --git a/.evergreen/config.yml b/.evergreen/config.yml
index f854f6bd3d..d83a5620df 100644
--- a/.evergreen/config.yml
+++ b/.evergreen/config.yml
@@ -25,6 +25,7 @@ timeout:
       binary: ls -la
 
 include:
+  - filename: .evergreen/generated_configs/functions.yml
   - filename: .evergreen/generated_configs/tasks.yml
   - filename: .evergreen/generated_configs/variants.yml
 
@@ -42,7 +43,7 @@ functions:
   # Make an evergreen expansion file with dynamic values
   - command: subprocess.exec
     params:
-      include_expansions_in_env: ["is_patch", "project", "version_id", "AUTH", "SSL", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "SETDEFAULTENCODING", "test_loadbalancer", "test_serverless", "SKIP_CSOT_TESTS", "MONGODB_STARTED", "DISABLE_TEST_COMMANDS", "GREEN_FRAMEWORK", "NO_EXT", "COVERAGE", "COMPRESSORS", "TEST_SUITES", "MONGODB_API_VERSION", "skip_crypt_shared", "VERSION", "TOPOLOGY", "STORAGE_ENGINE", "ORCHESTRATION_FILE", "REQUIRE_API_VERSION", "LOAD_BALANCER", "skip_web_identity_auth_test", "skip_ECS_auth_test"]
+      include_expansions_in_env: ["is_patch", "project", "version_id"]
       binary: bash
       working_dir: "src"
       args:
@@ -52,147 +53,6 @@ functions:
       params:
         file: src/expansion.yml
 
-  "upload coverage" :
-    - command: ec2.assume_role
-      params:
-        role_arn: ${assume_role_arn}
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: src/.coverage
-        optional: true
-        # Upload the coverage report for all tasks in a single build to the same directory.
-        remote_file: coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: text/html
-        display_name: "Raw Coverage Report"
-
-  "download and merge coverage" :
-    - command: ec2.assume_role
-      params:
-        role_arn: ${assume_role_arn}
-    - command: subprocess.exec
-      params:
-        silent: true
-        binary: bash
-        working_dir: "src"
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-        args:
-          - .evergreen/scripts/download-and-merge-coverage.sh
-          - ${bucket_name}
-          - ${revision}
-          - ${version_id}
-    - command: subprocess.exec
-      params:
-        working_dir: "src"
-        binary: bash
-        args:
-          - .evergreen/combine-coverage.sh
-    # Upload the resulting html coverage report.
-    - command: subprocess.exec
-      params:
-        silent: true
-        binary: bash
-        working_dir: "src"
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-        args:
-          - .evergreen/scripts/upload-coverage-report.sh
-          - ${bucket_name}
-          - ${revision}
-          - ${version_id}
-    # Attach the index.html with s3.put so it shows up in the Evergreen UI.
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: src/htmlcov/index.html
-        remote_file: coverage/${revision}/${version_id}/htmlcov/index.html
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: text/html
-        display_name: "Coverage Report HTML"
-
-  "upload mo artifacts":
-    - command: ec2.assume_role
-      params:
-        role_arn: ${assume_role_arn}
-    - command: archive.targz_pack
-      params:
-        target: "mongo-coredumps.tgz"
-        source_dir: "./"
-        include:
-          - "./**.core"
-          - "./**.mdmp" # Windows: minidumps
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: mongo-coredumps.tgz
-        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: ${content_type|application/gzip}
-        display_name: Core Dumps - Execution
-        optional: true
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: ${DRIVERS_TOOLS}/.evergreen/test_logs.tar.gz
-        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-drivers-tools-logs.tar.gz
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: ${content_type|application/x-gzip}
-        display_name: "drivers-tools-logs.tar.gz"
-
-  "upload working dir":
-    - command: ec2.assume_role
-      params:
-        role_arn: ${assume_role_arn}
-    - command: archive.targz_pack
-      params:
-        target: "working-dir.tar.gz"
-        source_dir: ${PROJECT_DIRECTORY}/
-        include:
-          - "./**"
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: working-dir.tar.gz
-        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-working-dir.tar.gz
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: ${content_type|application/x-gzip}
-        display_name: "working-dir.tar.gz"
-    - command: archive.targz_pack
-      params:
-        target: "drivers-dir.tar.gz"
-        source_dir: ${DRIVERS_TOOLS}
-        include:
-          - "./**"
-        exclude_files:
-          # Windows cannot read the mongod *.lock files because they are locked.
-          - "*.lock"
-    - command: s3.put
-      params:
-        aws_key: ${AWS_ACCESS_KEY_ID}
-        aws_secret: ${AWS_SECRET_ACCESS_KEY}
-        aws_session_token: ${AWS_SESSION_TOKEN}
-        local_file: drivers-dir.tar.gz
-        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/artifacts/${task_id}-${execution}-drivers-dir.tar.gz
-        bucket: ${bucket_name}
-        permissions: public-read
-        content_type: ${content_type|application/x-gzip}
-        display_name: "drivers-dir.tar.gz"
 
   "upload test results":
     - command: attach.results
       params:
@@ -201,309 +61,44 @@ functions:
       params:
         file: "src/xunit-results/TEST-*.xml"
 
-  "bootstrap mongo-orchestration":
+  "run server":
     - command: subprocess.exec
       params:
         binary: bash
-        include_expansions_in_env: ["VERSION", "TOPOLOGY", "AUTH", "SSL", "ORCHESTRATION_FILE", "LOAD_BALANCER"]
-        args:
-          - src/.evergreen/scripts/run-with-env.sh
-          - src/.evergreen/scripts/bootstrap-mongo-orchestration.sh
-    - command: expansions.update
-      params:
-        file: mo-expansion.yml
+        working_dir: "src"
+        include_expansions_in_env: [VERSION, TOPOLOGY, AUTH, SSL, ORCHESTRATION_FILE, PYTHON_BINARY, PYTHON_VERSION,
+          STORAGE_ENGINE, REQUIRE_API_VERSION, DRIVERS_TOOLS, TEST_CRYPT_SHARED, AUTH_AWS, LOAD_BALANCER, LOCAL_ATLAS]
+        args: [.evergreen/just.sh, run-server, "${TEST_NAME}"]
     - command: expansions.update
       params:
-        updates:
-          - key: MONGODB_STARTED
-            value: "1"
-
-  "bootstrap data lake":
-    - command: subprocess.exec
-      type: setup
-      params:
-        binary: bash
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/pull-mongohouse-image.sh
-    - command: subprocess.exec
-      type: setup
-      params:
-        binary: bash
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/atlas_data_lake/run-mongohouse-image.sh
-
-  "stop mongo-orchestration":
-    - command: subprocess.exec
-      params:
-        binary: bash
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh
+        file: ${DRIVERS_TOOLS}/mo-expansion.yml
 
-  "run mod_wsgi tests":
+  "run just script":
    - command: subprocess.exec
      type: test
      params:
-        include_expansions_in_env: [MOD_WSGI_VERSION, MOD_WSGI_EMBEDDED, "PYTHON_BINARY"]
-        working_dir: "src"
-        binary: bash
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mod-wsgi-tests.sh
-
-  "run mockupdb tests":
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["PYTHON_BINARY"]
-        working_dir: "src"
-        binary: bash
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mockupdb-tests.sh
-
-  "run doctests":
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: [ "PYTHON_BINARY" ]
-        working_dir: "src"
-        binary: bash
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-doctests.sh
+        include_expansions_in_env: [AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN]
+        binary: bash
+        working_dir: "src"
+        args: [.evergreen/just.sh, "${JUSTFILE_TARGET}"]
 
   "run tests":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      params:
-        working_dir: "src"
-        binary: bash
-        background: true
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/setup-encryption.sh
     - command: subprocess.exec
       type: test
       params:
-        working_dir: "src"
+        include_expansions_in_env: [AUTH, SSL, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
+          AWS_SESSION_TOKEN, COVERAGE, PYTHON_BINARY, LIBMONGOCRYPT_URL, MONGODB_URI, PYTHON_VERSION,
+          DISABLE_TEST_COMMANDS, GREEN_FRAMEWORK, NO_EXT, COMPRESSORS, MONGODB_API_VERSION, DEBUG_LOG,
+          ORCHESTRATION_FILE, OCSP_SERVER_TYPE, VERSION]
         binary: bash
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY", "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "SINGLE_MONGOS_LB_URI", "MULTI_MONGOS_LB_URI", "TEST_SUITES"]
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-tests.sh
-
-  "run direct tests":
-    - command: subprocess.exec
-      type: test
-      params:
         working_dir: "src"
-        binary: bash
-        include_expansions_in_env: ["PYTHON_BINARY"]
-        args: [ .evergreen/scripts/run-direct-tests.sh ]
-
-  "run enterprise auth tests":
+        args: [.evergreen/just.sh, setup-tests, "${TEST_NAME}", "${SUB_TEST_NAME}"]
     - command: subprocess.exec
       type: test
      params:
-        binary: bash
         working_dir: "src"
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY"]
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-enterprise-auth-tests.sh
-
-  "run atlas tests":
-    - command: subprocess.exec
-      type: test
-      params:
         binary: bash
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "PYTHON_BINARY"]
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-atlas-tests.sh
-
-  "get aws auth secrets":
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-        binary: bash
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/auth_aws/setup-secrets.sh
-
-  "run aws auth test with regular aws credentials":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: ["TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - regular
-
-  "run aws auth test with assume role credentials":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - assume-role
-
-  "run aws auth test with aws EC2 credentials":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - ec2
-
-  "run aws auth test with aws web identity credentials":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-
-    # Test with and without AWS_ROLE_SESSION_NAME set.
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - web-identity
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: [ "DRIVERS_TOOLS", "skip_EC2_auth_test" ]
-        binary: bash
-        working_dir: "src"
-        env:
-          AWS_ROLE_SESSION_NAME: test
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - web-identity
-
-  "run aws auth test with aws credentials as environment variables":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - env-creds
-
-  "run aws auth test with aws credentials and session token as environment variables":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["DRIVERS_TOOLS", "skip_EC2_auth_test"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-mongodb-aws-test.sh
-          - session-creds
-
-  "run oidc auth test with test credentials":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        working_dir: "src"
-        binary: bash
-        include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-        args:
-          - .evergreen/run-mongodb-oidc-test.sh
-
-  "run oidc k8s auth test":
-    - command: subprocess.exec
-      type: test
-      params:
-        binary: bash
-        working_dir: src
-        env:
-          OIDC_ENV: k8s
-        include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", "K8S_VARIANT"]
-        args:
-          - ${PROJECT_DIRECTORY}/.evergreen/run-mongodb-oidc-remote-test.sh
-
-  "run aws ECS auth test":
-    - command: subprocess.exec
-      type: test
-      params:
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-aws-ecs-auth-test.sh
+        args: [.evergreen/just.sh, run-tests]
 
   "cleanup":
     - command: subprocess.exec
@@ -511,29 +106,14 @@
         binary: bash
         working_dir: "src"
         args:
-          - .evergreen/scripts/run-with-env.sh
           - .evergreen/scripts/cleanup.sh
 
   "teardown system":
     - command: subprocess.exec
       params:
-        binary: bash
-        working_dir: "src"
-        args:
-          # Ensure the instance profile is reassigned for aws tests.
-          - ${DRIVERS_TOOLS}/.evergreen/auth_aws/teardown.sh
-    - command: subprocess.exec
-      params:
-        binary: bash
-        working_dir: "src"
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/csfle/teardown.sh
-    - command: subprocess.exec
-      params:
-        binary: bash
-        working_dir: "src"
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/ocsp/teardown.sh
+        binary: bash
+        working_dir: "src"
+        args: [.evergreen/just.sh, teardown-tests]
     - command: subprocess.exec
       params:
         binary: bash
@@ -545,83 +125,7 @@
     - command: ec2.assume_role
       params:
         role_arn: ${aws_test_secrets_role}
-
-  "setup atlas":
-    - command: subprocess.exec
-      params:
-        binary: bash
-        include_expansions_in_env: ["task_id", "execution"]
-        env:
-          MONGODB_VERSION: "7.0"
-          LAMBDA_STACK_NAME: dbx-python-lambda
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/atlas/setup-atlas-cluster.sh
-    - command: expansions.update
-      params:
-        file: atlas-expansion.yml
-
-  "run-ocsp-test":
-    - command: subprocess.exec
-      params:
-        include_expansions_in_env: [ "TEST_DATA_LAKE", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE" ]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/setup-tests.sh
-    - command: subprocess.exec
-      type: test
-      params:
-        include_expansions_in_env: ["OCSP_ALGORITHM", "OCSP_TLS_SHOULD_SUCCEED", "PYTHON_BINARY"]
-        binary: bash
-        working_dir: "src"
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-ocsp-test.sh
-
-  "run-ocsp-server":
-    - command: subprocess.exec
-      params:
-        background: true
-        binary: bash
-        include_expansions_in_env: [SERVER_TYPE, OCSP_ALGORITHM]
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/ocsp/setup.sh
-
-  "run load-balancer":
-    - command: subprocess.exec
-      params:
-        binary: bash
-        include_expansions_in_env: ["MONGODB_URI"]
-        args:
-          - src/.evergreen/scripts/run-with-env.sh
-          - src/.evergreen/scripts/run-load-balancer.sh
-    - command: expansions.update
-      params:
-        file: lb-expansion.yml
-
-  "stop load-balancer":
-    - command: subprocess.exec
-      params:
-        binary: bash
-        args:
-          - src/.evergreen/scripts/stop-load-balancer.sh
-
-  "teardown atlas":
-    - command: subprocess.exec
-      params:
-        binary: bash
-        args:
-          - ${DRIVERS_TOOLS}/.evergreen/atlas/teardown-atlas-cluster.sh
-
-  "run perf tests":
-    - command: subprocess.exec
-      type: test
-      params:
-        working_dir: "src"
-        binary: bash
-        args:
-          - .evergreen/scripts/run-with-env.sh
-          - .evergreen/scripts/run-perf-tests.sh
+        duration_seconds: 3600
 
   "attach benchmark test results":
     - command: attach.results
@@ -645,1039 +149,4 @@ post:
   - func: "upload coverage"
   - func: "upload mo artifacts"
   - func: "upload test results"
-  - func: "stop mongo-orchestration"
   - func: "cleanup"
-
-task_groups:
-  - name: serverless_task_group
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800 # 30 minutes
-    setup_group:
-      - func: "fetch source"
-      - func: "setup system"
-      - command: subprocess.exec
-        params:
-          binary: bash
-          env:
-            VAULT_NAME: ${VAULT_NAME}
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/serverless/create-instance.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/serverless/delete-instance.sh
-      - func: "upload test results"
-    tasks:
-      - ".serverless"
-
-  - name: testgcpkms_task_group
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800 # 30 minutes
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/create-and-setup-instance.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/delete-instance.sh
-      - func: "upload test results"
-    tasks:
-      - testgcpkms-task
-
-  - name: testazurekms_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - command: subprocess.exec
-        params:
-          binary: bash
-          env:
-            AZUREKMS_VMNAME_PREFIX: "PYTHON_DRIVER"
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/csfle/azurekms/create-and-setup-vm.sh
-    teardown_group:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/csfle/azurekms/delete-vm.sh
-      - func: "upload test results"
-    setup_group_can_fail_task: true
-    teardown_task_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - testazurekms-task
-
-  - name: testazureoidc_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - command: subprocess.exec
-        params:
-          binary: bash
-          env:
-            AZUREOIDC_VMNAME_PREFIX: "PYTHON_DRIVER"
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/create-and-setup-vm.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/delete-vm.sh
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - oidc-auth-test-azure
-
-  - name: testgcpoidc_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - command: subprocess.exec
-        params:
-          binary: bash
-          env:
-            GCPOIDC_VMNAME_PREFIX: "PYTHON_DRIVER"
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/setup.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/teardown.sh
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - oidc-auth-test-gcp
-
-  - name: testk8soidc_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - command: ec2.assume_role
-        params:
-          role_arn: ${aws_test_secrets_role}
-          duration_seconds: 1800
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/k8s/setup.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/k8s/teardown.sh
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - oidc-auth-test-k8s
-
-  - name: testoidc_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - func: "assume ec2 role"
-      - command: subprocess.exec
-        params:
-          binary: bash
-          include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh
-    teardown_task:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - oidc-auth-test
-
-  - name: test_aws_lambda_task_group
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - func: setup atlas
-    teardown_task:
-      - func: teardown atlas
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - test-aws-lambda-deployed
-
-  - name: test_atlas_task_group_search_indexes
-    setup_group:
-      - func: fetch source
-      - func: setup system
-      - func: setup atlas
-    teardown_task:
-      - func: teardown atlas
-    setup_group_can_fail_task: true
-    setup_group_timeout_secs: 1800
-    tasks:
-      - test-search-index-helpers
-
-tasks:
-  # Wildcard task. Do you need to find out what tools are available and where?
-  # Throw it here, and execute this task on all buildvariants
-  - name: getdata
-    commands:
-      - command: subprocess.exec
-        binary: bash
-        type: test
-        params:
-          args:
-            - src/.evergreen/scripts/run-getdata.sh
-# Standard test tasks {{{
-
-  - name: "mockupdb"
-    tags: ["mockupdb"]
-    commands:
-      - func: "run mockupdb tests"
-
-  - name: "doctests"
-    tags: ["doctests"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - func: "run doctests"
-
-  - name: "test-serverless"
-    tags: ["serverless"]
-    commands:
-      - func: "run tests"
-
-  - name: "test-enterprise-auth"
-    tags: ["enterprise-auth"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - func: "assume ec2 role"
-      - func: "run enterprise auth tests"
-
-  - name: "test-search-index-helpers"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "6.0"
-          TOPOLOGY: "replica_set"
-      - func: "run tests"
-        vars:
-          TEST_INDEX_MANAGEMENT: "1"
-
-  - name: "mod-wsgi-standalone"
-    tags: ["mod_wsgi"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - func: "run mod_wsgi tests"
-
-  - name: "mod-wsgi-replica-set"
-    tags: ["mod_wsgi"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "replica_set"
-      - func: "run mod_wsgi tests"
-
-  - name: "mod-wsgi-embedded-mode-standalone"
-    tags: ["mod_wsgi"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - func: "run mod_wsgi tests"
-        vars:
-          MOD_WSGI_EMBEDDED: "1"
-
-  - name: "mod-wsgi-embedded-mode-replica-set"
-    tags: ["mod_wsgi"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "replica_set"
-      - func: "run mod_wsgi tests"
-        vars:
-          MOD_WSGI_EMBEDDED: "1"
-
-  - name: "no-server"
-    tags: ["no-server"]
-    commands:
-      - func: "run tests"
-
-  - name: "free-threading"
-    tags: ["free-threading"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "8.0"
-          TOPOLOGY: "replica_set"
-      - func: "run direct tests"
-
-  - name: "atlas-connect"
-    tags: ["atlas-connect"]
-    commands:
-      - func: "assume ec2 role"
-      - func: "run atlas tests"
-
-  - name: atlas-data-lake-tests
-    commands:
-      - func: "bootstrap data lake"
-      - func: "run tests"
-        vars:
-          TEST_DATA_LAKE: "true"
-
-  - name: "test-aws-lambda-deployed"
-    commands:
-      - command: ec2.assume_role
-        params:
-          role_arn: ${LAMBDA_AWS_ROLE_ARN}
-          duration_seconds: 3600
-      - command: subprocess.exec
-        params:
-          working_dir: src
-          binary: bash
-          add_expansions_to_env: true
-          args:
-            - .evergreen/run-deployed-lambda-aws-tests.sh
-          env:
-            TEST_LAMBDA_DIRECTORY: ${PROJECT_DIRECTORY}/test/lambda
-
-  - name: test-ocsp-rsa-valid-cert-server-staples
-    tags: ["ocsp", "ocsp-rsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: "valid"
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-rsa-invalid-cert-server-staples
-    tags: ["ocsp", "ocsp-rsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: "revoked"
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-valid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: valid
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-rsa-invalid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: revoked
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-soft-fail
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-rsa-malicious-invalid-cert-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: revoked
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-malicious-no-responder-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-delegate-valid-cert-server-staples
-    tags: ["ocsp", "ocsp-rsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: valid-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples
-    tags: ["ocsp", "ocsp-rsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: revoked-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: valid-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: revoked-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-rsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          SERVER_TYPE: revoked-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "rsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "rsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-valid-cert-server-staples
-    tags: ["ocsp", "ocsp-ecdsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: valid
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-ecdsa-invalid-cert-server-staples
-    tags: ["ocsp", "ocsp-ecdsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: revoked
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: valid
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: revoked
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-soft-fail
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-ecdsa-malicious-invalid-cert-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: revoked
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-malicious-no-responder-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples
-    tags: ["ocsp", "ocsp-ecdsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: valid-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples
-    tags: ["ocsp", "ocsp-ecdsa", "ocsp-staple"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: revoked-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: valid-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "true"
-
-  - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: revoked-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-mustStaple-server-does-not-staple
-    tags: ["ocsp", "ocsp-ecdsa"]
-    commands:
-      - func: run-ocsp-server
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          SERVER_TYPE: valid-delegate
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          ORCHESTRATION_FILE: "ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json"
-      - func: run-ocsp-test
-        vars:
-          OCSP_ALGORITHM: "ecdsa"
-          OCSP_TLS_SHOULD_SUCCEED: "false"
-
-  - name: "aws-auth-test-4.4"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "4.4"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-5.0"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "5.0"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-6.0"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "6.0"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-7.0"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "7.0"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-8.0"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "8.0"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-rapid"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "rapid"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "aws-auth-test-latest"
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          AUTH: "auth"
-          ORCHESTRATION_FILE: "auth-aws.json"
-          TOPOLOGY: "server"
-          VERSION: "latest"
-      - func: "assume ec2 role"
-      - func: "get aws auth secrets"
-      - func: "run aws auth test with regular aws credentials"
-      - func: "run aws auth test with assume role credentials"
-      - func: "run aws auth test with aws credentials as environment variables"
-      - func: "run aws auth test with aws credentials and session token as environment variables"
-      - func: "run aws auth test with aws EC2 credentials"
-      - func: "run aws auth test with aws web identity credentials"
-      - func: "run aws ECS auth test"
-
-  - name: "oidc-auth-test"
-    commands:
-      - func: "run oidc auth test with test credentials"
-
-  - name: "oidc-auth-test-azure"
-    commands:
-      - command: subprocess.exec
-        type: test
-        params:
-          binary: bash
-          working_dir: src
-          env:
-            OIDC_ENV: azure
-          include_expansions_in_env: ["DRIVERS_TOOLS"]
-          args:
-            - ${PROJECT_DIRECTORY}/.evergreen/run-mongodb-oidc-remote-test.sh
-
-  - name: "oidc-auth-test-gcp"
-    commands:
-      - command: subprocess.exec
-        type: test
-        params:
-          binary: bash
-          working_dir: src
-          env:
-            OIDC_ENV: gcp
-          include_expansions_in_env: ["DRIVERS_TOOLS"]
-          args:
-            - ${PROJECT_DIRECTORY}/.evergreen/run-mongodb-oidc-remote-test.sh
-
-  - name: "oidc-auth-test-k8s"
-    commands:
-      - func: "run oidc k8s auth test"
-        vars:
-          K8S_VARIANT: eks
-      - func: "run oidc k8s auth test"
-        vars:
-          K8S_VARIANT: gke
-      - func: "run oidc k8s auth test"
-        vars:
-          K8S_VARIANT: aks
-# }}}
-  - name: "coverage-report"
-    tags: ["coverage"]
-    depends_on:
-      # BUILD-3165: We can't use "*" (all tasks) and specify "variant".
-      # Instead list out all coverage tasks using tags.
-      - name: ".standalone"
-        variant: ".coverage_tag"
-        # Run the coverage task even if some tasks fail.
-        status: "*"
-        # Run the coverage task even if some tasks are not scheduled in a patch build.
-        patch_optional: true
-      - name: ".replica_set"
-        variant: ".coverage_tag"
-        status: "*"
-        patch_optional: true
-      - name: ".sharded_cluster"
-        variant: ".coverage_tag"
-        status: "*"
-        patch_optional: true
-    commands:
-      - func: "download and merge coverage"
-
-  - name: "testgcpkms-task"
-    commands:
-      - command: subprocess.exec
-        type: setup
-        params:
-          working_dir: "src"
-          binary: bash
-          include_expansions_in_env: ["DRIVERS_TOOLS"]
-          args:
-            - .evergreen/run-gcpkms-test.sh
-
-  - name: "testgcpkms-fail-task"
-    # testgcpkms-fail-task runs in a non-GCE environment.
-    # It is expected to fail to obtain GCE credentials.
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - command: subprocess.exec
-        type: test
-        params:
-          include_expansions_in_env: ["PYTHON_BINARY"]
-          working_dir: "src"
-          binary: "bash"
-          args:
-            - .evergreen/scripts/run-gcpkms-fail-test.sh
-
-  - name: testazurekms-task
-    commands:
-      - command: subprocess.exec
-        params:
-          binary: bash
-          working_dir: src
-          include_expansions_in_env: ["DRIVERS_TOOLS"]
-          args:
-            - .evergreen/run-azurekms-test.sh
-
-  - name: testazurekms-fail-task
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "latest"
-          TOPOLOGY: "server"
-      - command: subprocess.exec
-        type: test
-        params:
-          binary: bash
-          working_dir: src
-          include_expansions_in_env: ["DRIVERS_TOOLS"]
-          args:
-            - .evergreen/run-azurekms-fail-test.sh
-
-  - name: "perf-6.0-standalone"
-    tags: ["perf"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "v6.0-perf"
-          TOPOLOGY: "server"
-      - func: "run perf tests"
-      - func: "attach benchmark test results"
-      - func: "send dashboard data"
-
-  - name: "perf-6.0-standalone-ssl"
-    tags: ["perf"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "v6.0-perf"
-          TOPOLOGY: "server"
-          SSL: "ssl"
-      - func: "run perf tests"
-      - func: "attach benchmark test results"
-      - func: "send dashboard data"
-
-  - name: "perf-8.0-standalone"
-    tags: ["perf"]
-    commands:
-      - func: "bootstrap mongo-orchestration"
-        vars:
-          VERSION: "8.0"
-          TOPOLOGY: "server"
-      - func: "run perf tests"
-      - func: "attach benchmark test results"
-      - func: "send dashboard data"
-
-  - name: "check-import-time"
-    tags: ["pr"]
-    commands:
-      - command: subprocess.exec
-        type: test
-        params:
-          binary: bash
-          working_dir: src
-          include_expansions_in_env: ["PYTHON_BINARY"]
-          args:
-            - .evergreen/scripts/check-import-time.sh
-            - ${revision}
-            - ${github_commit}
-  - name: "backport-pr"
-    allowed_requesters: ["commit"]
-    commands:
-      - command: subprocess.exec
-        type: test
-        params:
-          binary: bash
-          args:
-            - ${DRIVERS_TOOLS}/.evergreen/github_app/backport-pr.sh
-            - mongodb
-            - mongo-python-driver
-            - ${github_commit}
-
-buildvariants:
-- name: "no-server"
-  display_name: "No server"
-  run_on:
-    - rhel84-small
-  tasks:
-    - name: "no-server"
-
-- name: "Coverage Report"
-  display_name: "Coverage Report"
-  run_on:
-    - rhel84-small
-  tasks:
-    - name: "coverage-report"
-
-- name: testkms-variant
-  display_name: "KMS"
-  run_on:
-    - debian11-small
-  tasks:
-    - name: testgcpkms_task_group
-      batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README
-    - testgcpkms-fail-task
-    - name: testazurekms_task_group
-      batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README
-    - testazurekms-fail-task
-
-- name: rhel8-test-lambda
-  display_name: FaaS Lambda
-  run_on: rhel87-small
-  tasks:
-    - name: test_aws_lambda_task_group
-
-- name: rhel8-import-time
-  display_name: Import Time
-  run_on: rhel87-small
-  tasks:
-    - name: "check-import-time"
-
-- name: backport-pr
-  display_name: "Backport PR"
-  run_on:
-    - rhel8.7-small
-  tasks:
-    - name: "backport-pr"
-
-- name: "perf-tests"
-  display_name: "Performance Benchmarks"
-  batchtime: 10080 # 7 days
-  run_on: rhel90-dbx-perf-large
-  tasks:
-    - name: "perf-6.0-standalone"
-    - name: "perf-6.0-standalone-ssl"
-    - name: "perf-8.0-standalone"
-
-  # Platform notes
-  # i386 builds of OpenSSL or Cyrus SASL are not available
-  # Debian 8.1 only supports MongoDB 3.4+
-  # SUSE12 s390x is only supported by MongoDB 3.4+
-  # No enterprise build for Archlinux, SSL not available
-  # RHEL 7.6 and RHEL 8.4 only supports 3.6+.
-  # RHEL 7 only supports 2.6+
-  # RHEL 7.1 ppc64le is only supported by MongoDB 3.2+
-  # RHEL 7.2 s390x is only supported by MongoDB 3.4+
-  # Solaris MongoDB SSL builds are not available
-  # Darwin MongoDB SSL builds are not available for 2.6
-  # SUSE12 x86_64 is only supported by MongoDB 3.2+
-# vim: set et sw=2 ts=2 :
diff --git a/.evergreen/generated_configs/functions.yml b/.evergreen/generated_configs/functions.yml
new file mode 100644
index 0000000000..afd7f11374
--- /dev/null
+++ b/.evergreen/generated_configs/functions.yml
@@ -0,0 +1,117 @@
+functions:
+  # Download and merge coverage
+  download and merge coverage:
+    - command: ec2.assume_role
+      params:
+        role_arn: ${assume_role_arn}
+      type: setup
+    - command: subprocess.exec
+      params:
+        binary: bash
+        args:
+          - .evergreen/scripts/download-and-merge-coverage.sh
+          - ${bucket_name}
+          - ${revision}
+          - ${version_id}
+        working_dir: src
+        silent: true
+        include_expansions_in_env:
+          - AWS_ACCESS_KEY_ID
+          - AWS_SECRET_ACCESS_KEY
+          - AWS_SESSION_TOKEN
+      type: test
+    - command: subprocess.exec
+      params:
+        binary: bash
+        args:
+          - .evergreen/combine-coverage.sh
+        working_dir: src
+      type: test
+    - command: subprocess.exec
+      params:
+        binary: bash
+        args:
+          - .evergreen/scripts/upload-coverage-report.sh
+          - ${bucket_name}
+          - ${revision}
+          - ${version_id}
+        working_dir: src
+        silent: true
+        include_expansions_in_env:
+          - AWS_ACCESS_KEY_ID
+          - AWS_SECRET_ACCESS_KEY
+          - AWS_SESSION_TOKEN
+      type: test
+    - command: s3.put
+      params:
+        remote_file: coverage/${revision}/${version_id}/htmlcov/index.html
+        aws_key: ${AWS_ACCESS_KEY_ID}
+        aws_secret: ${AWS_SECRET_ACCESS_KEY}
+        aws_session_token: ${AWS_SESSION_TOKEN}
+        bucket: ${bucket_name}
+        local_file: src/htmlcov/index.html
+        permissions: public-read
+        content_type: text/html
+        display_name: Coverage Report HTML
+        optional: "true"
+      type: setup
+
+  # Upload coverage
+  upload coverage:
+    - command: ec2.assume_role
+      params:
+        role_arn: ${assume_role_arn}
+      type: setup
+    - command: s3.put
+      params:
+        remote_file: coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}
+        aws_key: ${AWS_ACCESS_KEY_ID}
+        aws_secret: ${AWS_SECRET_ACCESS_KEY}
+        aws_session_token: ${AWS_SESSION_TOKEN}
+        bucket: ${bucket_name}
+        local_file: src/.coverage
+        permissions: public-read
+        content_type: text/html
+        display_name: Raw Coverage Report
+        optional: "true"
+      type: setup
+
+  # Upload mo artifacts
+  upload mo artifacts:
+    - command: ec2.assume_role
+      params:
+        role_arn: ${assume_role_arn}
+      type: setup
+    - command: archive.targz_pack
+      params:
+        target: mongo-coredumps.tgz
+        source_dir: ./
+        include:
+          - ./**.core
+          - ./**.mdmp
+    - command: s3.put
+      params:
+        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz
+        aws_key: ${AWS_ACCESS_KEY_ID}
+        aws_secret: ${AWS_SECRET_ACCESS_KEY}
+        aws_session_token: ${AWS_SESSION_TOKEN}
+        bucket: ${bucket_name}
+        local_file: mongo-coredumps.tgz
+        permissions: public-read
+        content_type: ${content_type|application/x-gzip}
+        display_name: Core Dumps - Execution
+        optional: "true"
+      type: setup
+    - command: s3.put
+      params:
+        remote_file: ${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-drivers-tools-logs.tar.gz
+        aws_key: ${AWS_ACCESS_KEY_ID}
+        aws_secret: ${AWS_SECRET_ACCESS_KEY}
+        aws_session_token: ${AWS_SESSION_TOKEN}
+        bucket: ${bucket_name}
+        local_file: ${DRIVERS_TOOLS}/.evergreen/test_logs.tar.gz
+        permissions: public-read
+        content_type: ${content_type|application/x-gzip}
+        display_name: drivers-tools-logs.tar.gz
+        optional: "true"
+      type: setup
diff --git a/.evergreen/generated_configs/tasks.yml b/.evergreen/generated_configs/tasks.yml
index c666c6901a..b2b8dc1191 100644
--- a/.evergreen/generated_configs/tasks.yml
+++ b/.evergreen/generated_configs/tasks.yml
@@ -1,58 +1,3439 @@
 tasks:
+  # Atlas connect tests
+  - name: test-atlas-connect
+    commands:
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: atlas_connect
+    tags: [atlas_connect]
+
+  # Atlas data lake tests
+  - name: test-atlas-data-lake-without_ext
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: data_lake
+          NO_EXT: "1"
+    tags: [atlas_data_lake]
+  - name: test-atlas-data-lake-with_ext
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: data_lake
+    tags: [atlas_data_lake]
+
+  # Aws lambda tests
+  - name: test-aws-lambda-deployed
+    commands:
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: aws_lambda
+    tags: [aws_lambda]
+
+  # Aws tests
+  - name: test-auth-aws-4.4-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-4.4-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-4.4-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-4.4-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-4.4-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-4.4-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-4.4-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-4.4-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "4.4"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-5.0-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-5.0-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-5.0-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-5.0-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-5.0-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-5.0-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-5.0-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-5.0-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "5.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-6.0-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-6.0-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-6.0-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-6.0-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-6.0-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-6.0-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-6.0-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-6.0-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "6.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-7.0-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-7.0-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-7.0-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-7.0-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-7.0-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-7.0-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-7.0-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-7.0-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "7.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-8.0-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-8.0-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-8.0-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-8.0-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-8.0-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-8.0-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-8.0-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-8.0-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: "8.0"
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-rapid-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-rapid-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-rapid-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ec2
+    tags: [auth-aws, auth-aws-ec2]
+  - name: test-auth-aws-rapid-env-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: env-creds
+    tags: [auth-aws, auth-aws-env-creds]
+  - name: test-auth-aws-rapid-session-creds
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: session-creds
+    tags: [auth-aws, auth-aws-session-creds]
+  - name: test-auth-aws-rapid-web-identity
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-rapid-ecs
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: ecs
+    tags: [auth-aws, auth-aws-ecs]
+  - name: test-auth-aws-rapid-web-identity-session-name
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: rapid
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: web-identity
+          AWS_ROLE_SESSION_NAME: test
+    tags: [auth-aws, auth-aws-web-identity]
+  - name: test-auth-aws-latest-regular
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: latest
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: regular
+    tags: [auth-aws, auth-aws-regular]
+  - name: test-auth-aws-latest-assume-role
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: latest
+      - func: assume ec2 role
+      - func: run tests
+        vars:
+          TEST_NAME: auth_aws
+          SUB_TEST_NAME: assume-role
+    tags: [auth-aws, auth-aws-assume-role]
+  - name: test-auth-aws-latest-ec2
+    commands:
+      - func: run server
+        vars:
+          AUTH_AWS: "1"
+          VERSION: latest
+      - func: assume ec2 role
+      - func: run tests
+
vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: ec2 + tags: [auth-aws, auth-aws-ec2] + - name: test-auth-aws-latest-env-creds + commands: + - func: run server + vars: + AUTH_AWS: "1" + VERSION: latest + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: env-creds + tags: [auth-aws, auth-aws-env-creds] + - name: test-auth-aws-latest-session-creds + commands: + - func: run server + vars: + AUTH_AWS: "1" + VERSION: latest + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: session-creds + tags: [auth-aws, auth-aws-session-creds] + - name: test-auth-aws-latest-web-identity + commands: + - func: run server + vars: + AUTH_AWS: "1" + VERSION: latest + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: web-identity + tags: [auth-aws, auth-aws-web-identity] + - name: test-auth-aws-latest-ecs + commands: + - func: run server + vars: + AUTH_AWS: "1" + VERSION: latest + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: ecs + tags: [auth-aws, auth-aws-ecs] + - name: test-auth-aws-latest-web-identity-session-name + commands: + - func: run server + vars: + AUTH_AWS: "1" + VERSION: latest + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: auth_aws + SUB_TEST_NAME: web-identity + AWS_ROLE_SESSION_NAME: test + tags: [auth-aws, auth-aws-web-identity] + + # Backport pr tests + - name: backport-pr + commands: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/github_app/backport-pr.sh + - mongodb + - mongo-python-driver + - ${github_commit} + working_dir: src + type: test + + # Compression tests + - name: test-compression-v4.0-python3.9 + commands: + - func: run server + vars: + VERSION: "4.0" + - func: run tests + tags: [compression, "4.0"] + - name: test-compression-v4.2-python3.9 + commands: + - func: run server + vars: + VERSION: "4.2" + - func: run tests + tags: [compression, "4.2"] + - name: test-compression-v4.4-python3.9 + commands: + - func: run server + vars: + VERSION: "4.4" + - func: run tests + tags: [compression, "4.4"] + - name: test-compression-v5.0-python3.9 + commands: + - func: run server + vars: + VERSION: "5.0" + - func: run tests + tags: [compression, "5.0"] + - name: test-compression-v6.0-python3.9 + commands: + - func: run server + vars: + VERSION: "6.0" + - func: run tests + tags: [compression, "6.0"] + - name: test-compression-v7.0-python3.9 + commands: + - func: run server + vars: + VERSION: "7.0" + - func: run tests + tags: [compression, "7.0"] + - name: test-compression-v8.0-python3.9 + commands: + - func: run server + vars: + VERSION: "8.0" + - func: run tests + tags: [compression, "8.0"] + - name: test-compression-rapid-python3.9 + commands: + - func: run server + vars: + VERSION: rapid + - func: run tests + tags: [compression, rapid] + - name: test-compression-latest-python3.9 + commands: + - func: run server + vars: + VERSION: latest + - func: run tests + tags: [compression, latest] + - name: test-compression-latest-python3.13-no-c + commands: + - func: run server + vars: + VERSION: latest + - func: run tests + vars: + NO_EXT: "1" + tags: [compression, latest] + - name: test-compression-latest-python3.13 + commands: + - func: run server + vars: + VERSION: latest + - func: run tests + vars: {} + tags: [compression, latest] + - name: test-compression-latest-pypy3.10 + commands: + - func: run server + vars: + VERSION: latest + - func: run tests + tags: [compression, 
latest] + + # Coverage report tests + - name: coverage-report + commands: + - func: download and merge coverage + depends_on: + - name: .standalone + variant: .coverage_tag + status: "*" + patch_optional: true + - name: .replica_set + variant: .coverage_tag + status: "*" + patch_optional: true + - name: .sharded_cluster + variant: .coverage_tag + status: "*" + patch_optional: true + tags: [coverage] + + # Doctest tests + - name: test-doctests + commands: + - func: run server + - func: run just script + vars: + JUSTFILE_TARGET: docs-test + tags: [doctests] + + # Enterprise auth tests + - name: test-enterprise-auth-python3.9 + commands: + - func: run server + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: "3.9" + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: "3.9" + tags: [enterprise_auth] + - name: test-enterprise-auth-python3.13 + commands: + - func: run server + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: "3.13" + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: "3.13" + tags: [enterprise_auth] + - name: test-enterprise-auth-pypy3.10 + commands: + - func: run server + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: pypy3.10 + - func: assume ec2 role + - func: run tests + vars: + TEST_NAME: enterprise_auth + AUTH: auth + PYTHON_VERSION: pypy3.10 + tags: [enterprise_auth, pypy] + + # Free threading tests + - name: test-free-threading + commands: + - func: run server + vars: + VERSION: "8.0" + TOPOLOGY: replica_set + - func: run tests + tags: [free-threading] + + # Getdata tests + - name: getdata + commands: + - command: subprocess.exec + params: + binary: bash + args: + - .evergreen/scripts/run-getdata.sh + working_dir: src + type: test + + # Import time tests + - name: check-import-time + commands: + - command: subprocess.exec + params: + binary: bash + args: + - .evergreen/scripts/check-import-time.sh + - ${revision} + - ${github_commit} + working_dir: src + type: test + tags: [pr] + + # Kms tests + - name: test-gcpkms + commands: + - func: run tests + vars: + TEST_NAME: kms + SUB_TEST_NAME: gcp + - name: test-gcpkms-fail + commands: + - func: run server + - func: run tests + vars: + TEST_NAME: kms + SUB_TEST_NAME: gcp-fail + - name: test-azurekms + commands: + - func: run tests + vars: + TEST_NAME: kms + SUB_TEST_NAME: azure + - name: test-azurekms-fail + commands: + - func: run server + - func: run tests + vars: + TEST_NAME: kms + SUB_TEST_NAME: azure-fail + # Load balancer tests - - name: test-load-balancer-auth-ssl + - name: test-load-balancer-auth-ssl-v6.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "6.0" + - func: run tests + vars: + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, auth, ssl] + - name: test-load-balancer-auth-ssl-v7.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "7.0" + - func: run tests + vars: + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, auth, ssl] + - name: test-load-balancer-auth-ssl-v8.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "8.0" + - func: run tests + vars: + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, auth, ssl] + - name: 
test-load-balancer-auth-ssl-rapid + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + VERSION: rapid + - func: run tests + vars: + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, auth, ssl] + - name: test-load-balancer-auth-ssl-latest + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + VERSION: latest + - func: run tests + vars: + AUTH: auth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, auth, ssl] + - name: test-load-balancer-noauth-ssl-v6.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "6.0" + - func: run tests + vars: + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, ssl] + - name: test-load-balancer-noauth-ssl-v7.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "7.0" + - func: run tests + vars: + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, ssl] + - name: test-load-balancer-noauth-ssl-v8.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + VERSION: "8.0" + - func: run tests + vars: + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, ssl] + - name: test-load-balancer-noauth-ssl-rapid + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + VERSION: rapid + - func: run tests + vars: + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, ssl] + - name: test-load-balancer-noauth-ssl-latest + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + VERSION: latest + - func: run tests + vars: + AUTH: noauth + SSL: ssl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, ssl] + - name: test-load-balancer-noauth-nossl-v6.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + VERSION: "6.0" + - func: run tests + vars: + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, nossl] + - name: test-load-balancer-noauth-nossl-v7.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + VERSION: "7.0" + - func: run tests + vars: + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, nossl] + - name: test-load-balancer-noauth-nossl-v8.0 + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + VERSION: "8.0" + - func: run tests + vars: + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, nossl] + - name: test-load-balancer-noauth-nossl-rapid + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + VERSION: rapid + - func: run tests + vars: + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, nossl] + - name: test-load-balancer-noauth-nossl-latest + commands: + - func: run server + vars: + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + VERSION: latest + - func: run tests + 
vars: + AUTH: noauth + SSL: nossl + TEST_NAME: load_balancer + tags: [load-balancer, noauth, nossl] + + # Mockupdb tests + - name: test-mockupdb + commands: + - func: run tests + vars: + TEST_NAME: mockupdb + tags: [mockupdb] + + # Mod wsgi tests + - name: mod-wsgi-standalone + commands: + - func: run server + vars: + TOPOLOGY: standalone + - func: run tests + vars: + TEST_NAME: mod_wsgi + SUB_TEST_NAME: standalone + tags: [mod_wsgi] + - name: mod-wsgi-replica-set + commands: + - func: run server + vars: + TOPOLOGY: replica_set + - func: run tests + vars: + TEST_NAME: mod_wsgi + SUB_TEST_NAME: standalone + tags: [mod_wsgi] + - name: mod-wsgi-embedded-mode-standalone + commands: + - func: run server + vars: + TOPOLOGY: standalone + - func: run tests + vars: + TEST_NAME: mod_wsgi + SUB_TEST_NAME: embedded + tags: [mod_wsgi] + - name: mod-wsgi-embedded-mode-replica-set + commands: + - func: run server + vars: + TOPOLOGY: replica_set + - func: run tests + vars: + TEST_NAME: mod_wsgi + SUB_TEST_NAME: embedded + tags: [mod_wsgi] + + # No server tests + - name: test-no-server + commands: + - func: run tests + tags: [no-server] + + # Ocsp tests + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-valid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - 
name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-invalid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: 
ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-delegate-valid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-soft-fail-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-soft-fail-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-soft-fail-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + 
tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-soft-fail-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-soft-fail-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-soft-fail-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-soft-fail-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-valid-cert-server-staples-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: + - ocsp + - ocsp-ecdsa + - "4.4" + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: + - ocsp + - ocsp-ecdsa + - "5.0" + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: + - ocsp + - ocsp-ecdsa + - "6.0" + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: + - ocsp + - ocsp-ecdsa + - "7.0" + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: + - ocsp + - ocsp-ecdsa + - "8.0" + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: + - ocsp + - ocsp-ecdsa + - rapid + - ocsp-staple + - name: test-ocsp-ecdsa-valid-cert-server-staples-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: + - ocsp + - ocsp-ecdsa + - latest + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + 
PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: + - ocsp + - ocsp-ecdsa + - "4.4" + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: + - ocsp + - ocsp-ecdsa + - "5.0" + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: + - ocsp + - ocsp-ecdsa + - "6.0" + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: + - ocsp + - ocsp-ecdsa + - "7.0" + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: + - ocsp + - ocsp-ecdsa + - "8.0" + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: + - ocsp + - ocsp-ecdsa + - rapid + - ocsp-staple + - name: test-ocsp-ecdsa-invalid-cert-server-staples-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: + - ocsp + - ocsp-ecdsa + - latest + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: + - ocsp + - ocsp-ecdsa + - "4.4" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: + - ocsp + - ocsp-ecdsa + - "5.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: + - ocsp + - ocsp-ecdsa + - "6.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: + - ocsp + - ocsp-ecdsa + - "7.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: 
"8.0" + tags: + - ocsp + - ocsp-ecdsa + - "8.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: + - ocsp + - ocsp-ecdsa + - rapid + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-valid-cert-server-staples-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: + - ocsp + - ocsp-ecdsa + - latest + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: + - ocsp + - ocsp-ecdsa + - "4.4" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: + - ocsp + - ocsp-ecdsa + - "5.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: + - ocsp + - ocsp-ecdsa + - "6.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: + - ocsp + - ocsp-ecdsa + - "7.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: + - ocsp + - ocsp-ecdsa + - "8.0" + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: + - ocsp + - ocsp-ecdsa + - rapid + - ocsp-staple + - name: test-ocsp-ecdsa-delegate-invalid-cert-server-staples-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: + - ocsp + - ocsp-ecdsa + - latest + - ocsp-staple + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: 
ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-malicious-invalid-cert-muststaple-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: 
"3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-ecdsa, "4.4"] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-ecdsa, "5.0"] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-ecdsa, "6.0"] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-ecdsa, "7.0"] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-ecdsa, "8.0"] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-ecdsa, rapid] + - name: test-ocsp-ecdsa-malicious-no-responder-muststaple-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: ecdsa-basic-tls-ocsp-mustStaple-disableStapling.json + OCSP_SERVER_TYPE: no-responder + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-ecdsa, latest] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-rsa, "4.4"] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + 
ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-rsa, "5.0"] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-rsa, "6.0"] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-rsa, "7.0"] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-rsa, "8.0"] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-rsa, rapid] + - name: test-ocsp-rsa-valid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-rsa, latest] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-rsa, "4.4"] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-rsa, "5.0"] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-rsa, "6.0"] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-rsa, "7.0"] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-rsa, "8.0"] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-rsa, rapid] + - name: test-ocsp-rsa-invalid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + 
ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-rsa, latest] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-rsa, "4.4"] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-rsa, "5.0"] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-rsa, "6.0"] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "7.0" + tags: [ocsp, ocsp-rsa, "7.0"] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-v8.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "8.0" + tags: [ocsp, ocsp-rsa, "8.0"] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-rapid-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: rapid + tags: [ocsp, ocsp-rsa, rapid] + - name: test-ocsp-rsa-delegate-valid-cert-server-does-not-staple-latest-python3.13 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: valid-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.13" + VERSION: latest + tags: [ocsp, ocsp-rsa, latest] + - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-v4.4-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "4.4" + tags: [ocsp, ocsp-rsa, "4.4"] + - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-v5.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "5.0" + tags: [ocsp, ocsp-rsa, "5.0"] + - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-v6.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + PYTHON_VERSION: "3.9" + VERSION: "6.0" + tags: [ocsp, ocsp-rsa, "6.0"] + - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-v7.0-python3.9 + commands: + - func: run tests + vars: + ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json + OCSP_SERVER_TYPE: revoked-delegate + TEST_NAME: ocsp + 
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags: [ocsp, ocsp-rsa, "7.0"]
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags: [ocsp, ocsp-rsa, "8.0"]
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags: [ocsp, ocsp-rsa, rapid]
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-does-not-staple-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags: [ocsp, ocsp-rsa, latest]
+  - name: test-ocsp-rsa-soft-fail-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags: [ocsp, ocsp-rsa, "4.4"]
+  - name: test-ocsp-rsa-soft-fail-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags: [ocsp, ocsp-rsa, "5.0"]
+  - name: test-ocsp-rsa-soft-fail-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags: [ocsp, ocsp-rsa, "6.0"]
+  - name: test-ocsp-rsa-soft-fail-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags: [ocsp, ocsp-rsa, "7.0"]
+  - name: test-ocsp-rsa-soft-fail-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags: [ocsp, ocsp-rsa, "8.0"]
+  - name: test-ocsp-rsa-soft-fail-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags: [ocsp, ocsp-rsa, rapid]
+  - name: test-ocsp-rsa-soft-fail-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags: [ocsp, ocsp-rsa, latest]
+  - name: test-ocsp-rsa-valid-cert-server-staples-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "4.4"
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "5.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "6.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "7.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "8.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - rapid
+      - ocsp-staple
+  - name: test-ocsp-rsa-valid-cert-server-staples-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - latest
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "4.4"
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "5.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "6.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "7.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "8.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - rapid
+      - ocsp-staple
+  - name: test-ocsp-rsa-invalid-cert-server-staples-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - latest
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "4.4"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "5.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "6.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "7.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "8.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - rapid
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-valid-cert-server-staples-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: valid-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - latest
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "4.4"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "5.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "6.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "7.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - "8.0"
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - rapid
+      - ocsp-staple
+  - name: test-ocsp-rsa-delegate-invalid-cert-server-staples-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags:
+      - ocsp
+      - ocsp-rsa
+      - latest
+      - ocsp-staple
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags: [ocsp, ocsp-rsa, "4.4"]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags: [ocsp, ocsp-rsa, "5.0"]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags: [ocsp, ocsp-rsa, "6.0"]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags: [ocsp, ocsp-rsa, "7.0"]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags: [ocsp, ocsp-rsa, "8.0"]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags: [ocsp, ocsp-rsa, rapid]
+  - name: test-ocsp-rsa-malicious-invalid-cert-muststaple-server-does-not-staple-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags: [ocsp, ocsp-rsa, latest]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags: [ocsp, ocsp-rsa, "4.4"]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags: [ocsp, ocsp-rsa, "5.0"]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags: [ocsp, ocsp-rsa, "6.0"]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags: [ocsp, ocsp-rsa, "7.0"]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags: [ocsp, ocsp-rsa, "8.0"]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags: [ocsp, ocsp-rsa, rapid]
+  - name: test-ocsp-rsa-delegate-malicious-invalid-cert-muststaple-server-does-not-staple-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: revoked-delegate
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags: [ocsp, ocsp-rsa, latest]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-v4.4-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "4.4"
+    tags: [ocsp, ocsp-rsa, "4.4"]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-v5.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "5.0"
+    tags: [ocsp, ocsp-rsa, "5.0"]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-v6.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "6.0"
+    tags: [ocsp, ocsp-rsa, "6.0"]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-v7.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "7.0"
+    tags: [ocsp, ocsp-rsa, "7.0"]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-v8.0-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: "8.0"
+    tags: [ocsp, ocsp-rsa, "8.0"]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-rapid-python3.9
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.9"
+          VERSION: rapid
+    tags: [ocsp, ocsp-rsa, rapid]
+  - name: test-ocsp-rsa-malicious-no-responder-muststaple-server-does-not-staple-latest-python3.13
+    commands:
+      - func: run tests
+        vars:
+          ORCHESTRATION_FILE: rsa-basic-tls-ocsp-mustStaple-disableStapling.json
+          OCSP_SERVER_TYPE: no-responder
+          TEST_NAME: ocsp
+          PYTHON_VERSION: "3.13"
+          VERSION: latest
+    tags: [ocsp, ocsp-rsa, latest]
+
+  # Oidc tests
+  - name: test-auth-oidc-default
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: default
+    tags: [auth_oidc]
+  - name: test-auth-oidc-azure
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: azure
+    tags: [auth_oidc, auth_oidc_remote]
+  - name: test-auth-oidc-gcp
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: gcp
+    tags: [auth_oidc, auth_oidc_remote]
+  - name: test-auth-oidc-eks
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: eks
+    tags: [auth_oidc, auth_oidc_remote]
+  - name: test-auth-oidc-aks
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: aks
+    tags: [auth_oidc, auth_oidc_remote]
+  - name: test-auth-oidc-gke
+    commands:
+      - func: run tests
+        vars:
+          TEST_NAME: auth_oidc
+          SUB_TEST_NAME: gke
+    tags: [auth_oidc, auth_oidc_remote]
+
+  # Perf tests
+  - name: perf-8.0-standalone-ssl
+    commands:
+      - func: run server
+        vars:
+          VERSION: v8.0-perf
+          SSL: ssl
+      - func: run tests
+        vars:
+          TEST_NAME: perf
+          SUB_TEST_NAME: sync
+      - func: attach benchmark test results
+      - func: send dashboard data
+    tags: [perf]
+  - name: perf-8.0-standalone-ssl-async
+    commands:
+      - func: run server
+        vars:
+          VERSION: v8.0-perf
+          SSL: ssl
+      - func: run tests
+        vars:
+          TEST_NAME: perf
+          SUB_TEST_NAME: async
+      - func: attach benchmark test results
+      - func: send dashboard data
+    tags: [perf]
+  - name: perf-8.0-standalone
+    commands:
+      - func: run server
+        vars:
+          VERSION: v8.0-perf
+          SSL: nossl
+      - func: run tests
+        vars:
+          TEST_NAME: perf
+          SUB_TEST_NAME: sync
+      - func: attach benchmark test results
+      - func: send dashboard data
+    tags: [perf]
+  - name: perf-8.0-standalone-async
+    commands:
+      - func: run server
+        vars:
+          VERSION: v8.0-perf
+          SSL: nossl
+      - func: run tests
+        vars:
+          TEST_NAME: perf
+          SUB_TEST_NAME: async
+      - func: attach benchmark test results
+      - func: send dashboard data
+    tags: [perf]
+
+  # Search index tests
+  - name: test-search-index-helpers
+    commands:
+      - func: assume ec2 role
+      - func: run server
+        vars:
+          TEST_NAME: search_index
+      - func: run tests
+        vars:
+          TEST_NAME: search_index
+    tags: [search_index]
+
+  # Server tests
+  - name: test-4.0-standalone-auth-ssl-sync
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: auth
+          SSL: ssl
+      - func: run tests
+        vars:
+          AUTH: auth
+          SSL: ssl
+          SYNC: sync
+          TEST_NAME: default_sync
+    tags:
+      - "4.0"
+      - standalone
+      - auth
+      - ssl
+      - sync
+  - name: test-4.0-standalone-auth-ssl-async
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: auth
+          SSL: ssl
+      - func: run tests
+        vars:
+          AUTH: auth
+          SSL: ssl
+          SYNC: async
+          TEST_NAME: default_async
+    tags:
+      - "4.0"
+      - standalone
+      - auth
+      - ssl
+      - async
+  - name: test-4.0-standalone-auth-ssl-sync_async
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: auth
+          SSL: ssl
+      - func: run tests
+        vars:
+          AUTH: auth
+          SSL: ssl
+          SYNC: sync_async
+    tags:
+      - "4.0"
+      - standalone
+      - auth
+      - ssl
+      - sync_async
+  - name: test-4.0-standalone-noauth-ssl-sync
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: noauth
+          SSL: ssl
+      - func: run tests
+        vars:
+          AUTH: noauth
+          SSL: ssl
+          SYNC: sync
+          TEST_NAME: default_sync
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - ssl
+      - sync
+  - name: test-4.0-standalone-noauth-ssl-async
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: noauth
+          SSL: ssl
+      - func: run tests
+        vars:
+          AUTH: noauth
+          SSL: ssl
+          SYNC: async
+          TEST_NAME: default_async
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - ssl
+      - async
+  - name: test-4.0-standalone-noauth-ssl-sync_async
     commands:
-      - func: bootstrap mongo-orchestration
+      - func: run server
         vars:
-          TOPOLOGY: sharded_cluster
-          AUTH: auth
+          VERSION: "4.0"
+          TOPOLOGY: server
+          AUTH: noauth
           SSL: ssl
-          LOAD_BALANCER: "true"
-      - func: run load-balancer
       - func: run tests
         vars:
-          AUTH: auth
+          AUTH: noauth
           SSL: ssl
-          test_loadbalancer: "true"
-    tags: [load-balancer, auth, ssl]
-  - name: test-load-balancer-noauth-ssl
+          SYNC: sync_async
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - ssl
+      - sync_async
+  - name: test-4.0-standalone-noauth-nossl-sync
     commands:
-      - func: bootstrap mongo-orchestration
+      - func: run server
         vars:
-          TOPOLOGY: sharded_cluster
+          VERSION: "4.0"
+          TOPOLOGY: server
           AUTH: noauth
-          SSL: ssl
-          LOAD_BALANCER: "true"
-      - func: run load-balancer
+          SSL: nossl
       - func: run tests
         vars:
           AUTH: noauth
-          SSL: ssl
-          test_loadbalancer: "true"
-    tags: [load-balancer, noauth, ssl]
-  - name: test-load-balancer-noauth-nossl
+          SSL: nossl
+          SYNC: sync
+          TEST_NAME: default_sync
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - nossl
+      - sync
+  - name: test-4.0-standalone-noauth-nossl-async
     commands:
-      - func: bootstrap mongo-orchestration
+      - func: run server
         vars:
-          TOPOLOGY: sharded_cluster
+          VERSION: "4.0"
+          TOPOLOGY: server
           AUTH: noauth
           SSL: nossl
-          LOAD_BALANCER: "true"
-      - func: run load-balancer
       - func: run tests
         vars:
           AUTH: noauth
           SSL: nossl
-          test_loadbalancer: "true"
-    tags: [load-balancer, noauth, nossl]
-
-  # Server tests
-  - name: test-4.0-standalone-auth-ssl-sync
+          SYNC: async
+          TEST_NAME: default_async
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - nossl
+      - async
+  - name: test-4.0-standalone-noauth-nossl-sync_async
     commands:
-      - func: bootstrap mongo-orchestration
+      - func: run server
         vars:
           VERSION: "4.0"
           TOPOLOGY: server
+          AUTH: noauth
+          SSL: nossl
+      - func: run tests
+        vars:
+          AUTH: noauth
+          SSL: nossl
+          SYNC: sync_async
+    tags:
+      - "4.0"
+      - standalone
+      - noauth
+      - nossl
+      - sync_async
+  - name: test-4.2-standalone-auth-ssl-sync
+    commands:
+      - func: run server
+        vars:
+          VERSION: "4.2"
+          TOPOLOGY: server
           AUTH: auth
           SSL: ssl
       - func: run tests
@@ -60,18 +3441,18 @@ tasks:
           AUTH: auth
           SSL: ssl
           SYNC: sync
-          TEST_SUITES: default
+          TEST_NAME: default_sync
     tags:
-      - "4.0"
+      - "4.2"
       - standalone
      - auth
      - ssl
      -
sync - - name: test-4.0-standalone-auth-ssl-async + - name: test-4.2-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: auth SSL: ssl @@ -80,18 +3461,18 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - standalone - auth - ssl - async - - name: test-4.0-standalone-auth-ssl-sync_async + - name: test-4.2-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: auth SSL: ssl @@ -100,18 +3481,17 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - standalone - auth - ssl - sync_async - - name: test-4.0-standalone-noauth-ssl-sync + - name: test-4.2-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: ssl @@ -120,18 +3500,18 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - "4.0" + - "4.2" - standalone - noauth - ssl - sync - - name: test-4.0-standalone-noauth-ssl-async + - name: test-4.2-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: ssl @@ -140,18 +3520,18 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - standalone - noauth - ssl - async - - name: test-4.0-standalone-noauth-ssl-sync_async + - name: test-4.2-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: ssl @@ -160,18 +3540,17 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - standalone - noauth - ssl - sync_async - - name: test-4.0-standalone-noauth-nossl-sync + - name: test-4.2-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: nossl @@ -180,18 +3559,18 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - "4.0" + - "4.2" - standalone - noauth - nossl - sync - - name: test-4.0-standalone-noauth-nossl-async + - name: test-4.2-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: nossl @@ -200,18 +3579,18 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - standalone - noauth - nossl - async - - name: test-4.0-standalone-noauth-nossl-sync_async + - name: test-4.2-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: server AUTH: noauth SSL: nossl @@ -220,16 +3599,15 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - standalone - noauth - nossl - sync_async - name: test-4.4-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -240,7 +3618,7 @@ tasks: AUTH: auth SSL: ssl SYNC: 
sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - standalone @@ -249,7 +3627,7 @@ tasks: - sync - name: test-4.4-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -260,7 +3638,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - standalone @@ -269,7 +3647,7 @@ tasks: - async - name: test-4.4-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -280,7 +3658,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - standalone @@ -289,7 +3666,7 @@ tasks: - sync_async - name: test-4.4-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -300,7 +3677,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - standalone @@ -309,7 +3686,7 @@ tasks: - sync - name: test-4.4-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -320,7 +3697,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - standalone @@ -329,7 +3706,7 @@ tasks: - async - name: test-4.4-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -340,7 +3717,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - standalone @@ -349,7 +3725,7 @@ tasks: - sync_async - name: test-4.4-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -360,7 +3736,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - standalone @@ -369,7 +3745,7 @@ tasks: - sync - name: test-4.4-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -380,7 +3756,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - standalone @@ -389,7 +3765,7 @@ tasks: - async - name: test-4.4-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: server @@ -400,7 +3776,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - standalone @@ -409,7 +3784,7 @@ tasks: - sync_async - name: test-5.0-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -420,7 +3795,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - standalone @@ -429,7 +3804,7 @@ tasks: - sync - name: test-5.0-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -440,7 +3815,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - standalone @@ -449,7 +3824,7 @@ tasks: - async - name: test-5.0-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -460,7 +3835,6 @@ tasks: AUTH: 
auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - standalone @@ -469,7 +3843,7 @@ tasks: - sync_async - name: test-5.0-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -480,7 +3854,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - standalone @@ -489,7 +3863,7 @@ tasks: - sync - name: test-5.0-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -500,7 +3874,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - standalone @@ -509,7 +3883,7 @@ tasks: - async - name: test-5.0-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -520,7 +3894,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - standalone @@ -529,7 +3902,7 @@ tasks: - sync_async - name: test-5.0-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -540,7 +3913,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - standalone @@ -549,7 +3922,7 @@ tasks: - sync - name: test-5.0-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -560,7 +3933,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - standalone @@ -569,7 +3942,7 @@ tasks: - async - name: test-5.0-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: server @@ -580,7 +3953,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - standalone @@ -589,7 +3961,7 @@ tasks: - sync_async - name: test-6.0-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -600,7 +3972,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - standalone @@ -609,7 +3981,7 @@ tasks: - sync - name: test-6.0-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -620,7 +3992,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - standalone @@ -629,7 +4001,7 @@ tasks: - async - name: test-6.0-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -640,7 +4012,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - standalone @@ -649,7 +4020,7 @@ tasks: - sync_async - name: test-6.0-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -660,7 +4031,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - standalone @@ -669,7 +4040,7 @@ tasks: - sync - name: test-6.0-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -680,7 +4051,7 @@ tasks: AUTH: noauth 
SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - standalone @@ -689,7 +4060,7 @@ tasks: - async - name: test-6.0-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -700,7 +4071,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - standalone @@ -709,7 +4079,7 @@ tasks: - sync_async - name: test-6.0-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -720,7 +4090,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - standalone @@ -729,7 +4099,7 @@ tasks: - sync - name: test-6.0-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -740,7 +4110,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - standalone @@ -749,7 +4119,7 @@ tasks: - async - name: test-6.0-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: server @@ -760,7 +4130,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - standalone @@ -769,7 +4138,7 @@ tasks: - sync_async - name: test-7.0-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -780,7 +4149,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - standalone @@ -789,7 +4158,7 @@ tasks: - sync - name: test-7.0-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -800,7 +4169,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - standalone @@ -809,7 +4178,7 @@ tasks: - async - name: test-7.0-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -820,7 +4189,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - standalone @@ -829,7 +4197,7 @@ tasks: - sync_async - name: test-7.0-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -840,7 +4208,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - standalone @@ -849,7 +4217,7 @@ tasks: - sync - name: test-7.0-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -860,7 +4228,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - standalone @@ -869,7 +4237,7 @@ tasks: - async - name: test-7.0-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -880,7 +4248,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - standalone @@ -889,7 +4256,7 @@ tasks: - sync_async - name: test-7.0-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -900,7 +4267,7 @@ tasks: 
AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - standalone @@ -909,7 +4276,7 @@ tasks: - sync - name: test-7.0-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -920,7 +4287,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - standalone @@ -929,7 +4296,7 @@ tasks: - async - name: test-7.0-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: server @@ -940,7 +4307,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - standalone @@ -949,7 +4315,7 @@ tasks: - sync_async - name: test-8.0-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -960,7 +4326,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - standalone @@ -969,7 +4335,7 @@ tasks: - sync - name: test-8.0-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -980,7 +4346,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - standalone @@ -989,7 +4355,7 @@ tasks: - async - name: test-8.0-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1000,7 +4366,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - standalone @@ -1009,7 +4374,7 @@ tasks: - sync_async - name: test-8.0-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1020,7 +4385,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - standalone @@ -1029,7 +4394,7 @@ tasks: - sync - name: test-8.0-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1040,7 +4405,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - standalone @@ -1049,7 +4414,7 @@ tasks: - async - name: test-8.0-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1060,7 +4425,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - standalone @@ -1069,7 +4433,7 @@ tasks: - sync_async - name: test-8.0-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1080,7 +4444,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - standalone @@ -1089,7 +4453,7 @@ tasks: - sync - name: test-8.0-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: server @@ -1100,7 +4464,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - standalone @@ -1109,7 +4473,7 @@ tasks: - async - name: test-8.0-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: 
"8.0" TOPOLOGY: server @@ -1120,7 +4484,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - standalone @@ -1129,7 +4492,7 @@ tasks: - sync_async - name: test-rapid-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1140,7 +4503,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - standalone @@ -1149,7 +4512,7 @@ tasks: - sync - name: test-rapid-standalone-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1160,7 +4523,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - standalone @@ -1169,7 +4532,7 @@ tasks: - async - name: test-rapid-standalone-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1180,7 +4543,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - standalone @@ -1189,7 +4551,7 @@ tasks: - sync_async - name: test-rapid-standalone-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1200,7 +4562,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - standalone @@ -1209,7 +4571,7 @@ tasks: - sync - name: test-rapid-standalone-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1220,7 +4582,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - standalone @@ -1229,7 +4591,7 @@ tasks: - async - name: test-rapid-standalone-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1240,7 +4602,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - standalone @@ -1249,7 +4610,7 @@ tasks: - sync_async - name: test-rapid-standalone-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1260,7 +4621,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - standalone @@ -1269,7 +4630,7 @@ tasks: - sync - name: test-rapid-standalone-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1280,7 +4641,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - standalone @@ -1289,7 +4650,7 @@ tasks: - async - name: test-rapid-standalone-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: server @@ -1300,7 +4661,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - standalone @@ -1309,10 +4669,187 @@ tasks: - sync_async - name: test-latest-standalone-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync + TEST_NAME: default_sync + tags: + - latest + - standalone + - auth + - ssl + - sync + - name: test-latest-standalone-auth-ssl-async + commands: + - func: run server + 
vars: + VERSION: latest + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: async + TEST_NAME: default_async + tags: + - latest + - standalone + - auth + - ssl + - async + - name: test-latest-standalone-auth-ssl-sync_async + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync_async + tags: + - latest + - standalone + - auth + - ssl + - sync_async + - name: test-latest-standalone-noauth-ssl-sync + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync + TEST_NAME: default_sync + tags: + - latest + - standalone + - noauth + - ssl + - sync + - name: test-latest-standalone-noauth-ssl-async + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: async + TEST_NAME: default_async + tags: + - latest + - standalone + - noauth + - ssl + - async + - name: test-latest-standalone-noauth-ssl-sync_async + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync_async + tags: + - latest + - standalone + - noauth + - ssl + - sync_async + - name: test-latest-standalone-noauth-nossl-sync + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync + TEST_NAME: default_sync + tags: + - latest + - standalone + - noauth + - nossl + - sync + - name: test-latest-standalone-noauth-nossl-async + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: async + TEST_NAME: default_async + tags: + - latest + - standalone + - noauth + - nossl + - async + - name: test-latest-standalone-noauth-nossl-sync_async + commands: + - func: run server + vars: + VERSION: latest + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync_async + tags: + - latest + - standalone + - noauth + - nossl + - sync_async + - name: test-4.0-replica_set-auth-ssl-sync + commands: + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: auth SSL: ssl - func: run tests @@ -1320,19 +4857,19 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - latest - - standalone + - "4.0" + - replica_set - auth - ssl - sync - - name: test-latest-standalone-auth-ssl-async + - name: test-4.0-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: auth SSL: ssl - func: run tests @@ -1340,19 +4877,19 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - latest - - standalone + - "4.0" + - replica_set - auth - ssl - async - - name: test-latest-standalone-auth-ssl-sync_async + - name: test-4.0-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: auth SSL: ssl - func: run tests @@ 
-1360,19 +4897,18 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - latest - - standalone + - "4.0" + - replica_set - auth - ssl - sync_async - - name: test-latest-standalone-noauth-ssl-sync + - name: test-4.0-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: ssl - func: run tests @@ -1380,19 +4916,19 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - ssl - sync - - name: test-latest-standalone-noauth-ssl-async + - name: test-4.0-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: ssl - func: run tests @@ -1400,19 +4936,19 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - ssl - async - - name: test-latest-standalone-noauth-ssl-sync_async + - name: test-4.0-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: ssl - func: run tests @@ -1420,19 +4956,18 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - ssl - sync_async - - name: test-latest-standalone-noauth-nossl-sync + - name: test-4.0-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: nossl - func: run tests @@ -1440,19 +4975,19 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - nossl - sync - - name: test-latest-standalone-noauth-nossl-async + - name: test-4.0-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: nossl - func: run tests @@ -1460,19 +4995,19 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - nossl - async - - name: test-latest-standalone-noauth-nossl-sync_async + - name: test-4.0-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: latest - TOPOLOGY: server + VERSION: "4.0" + TOPOLOGY: replica_set AUTH: noauth SSL: nossl - func: run tests @@ -1480,18 +5015,17 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - - latest - - standalone + - "4.0" + - replica_set - noauth - nossl - sync_async - - name: test-4.0-replica_set-auth-ssl-sync + - name: test-4.2-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: auth SSL: ssl @@ -1500,18 +5034,18 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - "4.0" + - "4.2" - replica_set - auth - ssl - sync - - name: 
test-4.0-replica_set-auth-ssl-async + - name: test-4.2-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: auth SSL: ssl @@ -1520,18 +5054,18 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - replica_set - auth - ssl - async - - name: test-4.0-replica_set-auth-ssl-sync_async + - name: test-4.2-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: auth SSL: ssl @@ -1540,18 +5074,17 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - replica_set - auth - ssl - sync_async - - name: test-4.0-replica_set-noauth-ssl-sync + - name: test-4.2-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: ssl @@ -1560,18 +5093,18 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - "4.0" + - "4.2" - replica_set - noauth - ssl - sync - - name: test-4.0-replica_set-noauth-ssl-async + - name: test-4.2-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: ssl @@ -1580,18 +5113,18 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - replica_set - noauth - ssl - async - - name: test-4.0-replica_set-noauth-ssl-sync_async + - name: test-4.2-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: ssl @@ -1600,18 +5133,17 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - replica_set - noauth - ssl - sync_async - - name: test-4.0-replica_set-noauth-nossl-sync + - name: test-4.2-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: nossl @@ -1620,18 +5152,18 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - - "4.0" + - "4.2" - replica_set - noauth - nossl - sync - - name: test-4.0-replica_set-noauth-nossl-async + - name: test-4.2-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: nossl @@ -1640,18 +5172,18 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - - "4.0" + - "4.2" - replica_set - noauth - nossl - async - - name: test-4.0-replica_set-noauth-nossl-sync_async + - name: test-4.2-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: - VERSION: "4.0" + VERSION: "4.2" TOPOLOGY: replica_set AUTH: noauth SSL: nossl @@ -1660,16 +5192,15 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - - "4.0" + - "4.2" - replica_set - noauth - nossl - sync_async - name: test-4.4-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: 
replica_set @@ -1680,7 +5211,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - replica_set @@ -1689,7 +5220,7 @@ tasks: - sync - name: test-4.4-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1700,7 +5231,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - replica_set @@ -1709,7 +5240,7 @@ tasks: - async - name: test-4.4-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1720,7 +5251,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - replica_set @@ -1729,7 +5259,7 @@ tasks: - sync_async - name: test-4.4-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1740,7 +5270,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - replica_set @@ -1749,7 +5279,7 @@ tasks: - sync - name: test-4.4-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1760,7 +5290,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - replica_set @@ -1769,7 +5299,7 @@ tasks: - async - name: test-4.4-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1780,7 +5310,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - replica_set @@ -1789,7 +5318,7 @@ tasks: - sync_async - name: test-4.4-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1800,7 +5329,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - replica_set @@ -1809,7 +5338,7 @@ tasks: - sync - name: test-4.4-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1820,7 +5349,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - replica_set @@ -1829,7 +5358,7 @@ tasks: - async - name: test-4.4-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: replica_set @@ -1840,7 +5369,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - replica_set @@ -1849,7 +5377,7 @@ tasks: - sync_async - name: test-5.0-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1860,7 +5388,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - replica_set @@ -1869,7 +5397,7 @@ tasks: - sync - name: test-5.0-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1880,7 +5408,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - replica_set @@ -1889,7 +5417,7 @@ tasks: - async - name: 
test-5.0-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1900,7 +5428,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - replica_set @@ -1909,7 +5436,7 @@ tasks: - sync_async - name: test-5.0-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1920,7 +5447,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - replica_set @@ -1929,7 +5456,7 @@ tasks: - sync - name: test-5.0-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1940,7 +5467,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - replica_set @@ -1949,7 +5476,7 @@ tasks: - async - name: test-5.0-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1960,7 +5487,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - replica_set @@ -1969,7 +5495,7 @@ tasks: - sync_async - name: test-5.0-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -1980,7 +5506,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - replica_set @@ -1989,7 +5515,7 @@ tasks: - sync - name: test-5.0-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -2000,7 +5526,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - replica_set @@ -2009,7 +5535,7 @@ tasks: - async - name: test-5.0-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: replica_set @@ -2020,7 +5546,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - replica_set @@ -2029,7 +5554,7 @@ tasks: - sync_async - name: test-6.0-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2040,7 +5565,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - replica_set @@ -2049,7 +5574,7 @@ tasks: - sync - name: test-6.0-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2060,7 +5585,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - replica_set @@ -2069,7 +5594,7 @@ tasks: - async - name: test-6.0-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2080,7 +5605,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - replica_set @@ -2089,7 +5613,7 @@ tasks: - sync_async - name: test-6.0-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2100,7 +5624,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + 
TEST_NAME: default_sync tags: - "6.0" - replica_set @@ -2109,7 +5633,7 @@ tasks: - sync - name: test-6.0-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2120,7 +5644,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - replica_set @@ -2129,7 +5653,7 @@ tasks: - async - name: test-6.0-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2140,7 +5664,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - replica_set @@ -2149,7 +5672,7 @@ tasks: - sync_async - name: test-6.0-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2160,7 +5683,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - replica_set @@ -2169,7 +5692,7 @@ tasks: - sync - name: test-6.0-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2180,7 +5703,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - replica_set @@ -2189,7 +5712,7 @@ tasks: - async - name: test-6.0-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: replica_set @@ -2200,7 +5723,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - replica_set @@ -2209,7 +5731,7 @@ tasks: - sync_async - name: test-7.0-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2220,7 +5742,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - replica_set @@ -2229,7 +5751,7 @@ tasks: - sync - name: test-7.0-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2240,7 +5762,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - replica_set @@ -2249,7 +5771,7 @@ tasks: - async - name: test-7.0-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2260,7 +5782,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - replica_set @@ -2269,7 +5790,7 @@ tasks: - sync_async - name: test-7.0-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2280,7 +5801,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - replica_set @@ -2289,7 +5810,7 @@ tasks: - sync - name: test-7.0-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2300,7 +5821,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - replica_set @@ -2309,7 +5830,7 @@ tasks: - async - name: test-7.0-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server 
vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2320,7 +5841,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - replica_set @@ -2329,7 +5849,7 @@ tasks: - sync_async - name: test-7.0-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2340,7 +5860,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - replica_set @@ -2349,7 +5869,7 @@ tasks: - sync - name: test-7.0-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2360,7 +5880,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - replica_set @@ -2369,7 +5889,7 @@ tasks: - async - name: test-7.0-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: replica_set @@ -2380,7 +5900,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - replica_set @@ -2389,7 +5908,7 @@ tasks: - sync_async - name: test-8.0-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2400,7 +5919,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - replica_set @@ -2409,7 +5928,7 @@ tasks: - sync - name: test-8.0-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2420,7 +5939,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - replica_set @@ -2429,7 +5948,7 @@ tasks: - async - name: test-8.0-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2440,7 +5959,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - replica_set @@ -2449,7 +5967,7 @@ tasks: - sync_async - name: test-8.0-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2460,7 +5978,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - replica_set @@ -2469,7 +5987,7 @@ tasks: - sync - name: test-8.0-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2480,7 +5998,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - replica_set @@ -2489,7 +6007,7 @@ tasks: - async - name: test-8.0-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2500,7 +6018,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - replica_set @@ -2509,7 +6026,7 @@ tasks: - sync_async - name: test-8.0-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2520,7 +6037,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - replica_set @@ -2529,7 +6046,7 @@ tasks: - sync - name: 
test-8.0-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2540,7 +6057,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - replica_set @@ -2549,7 +6066,7 @@ tasks: - async - name: test-8.0-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: replica_set @@ -2560,7 +6077,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - replica_set @@ -2569,7 +6085,7 @@ tasks: - sync_async - name: test-rapid-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2580,7 +6096,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - replica_set @@ -2589,7 +6105,7 @@ tasks: - sync - name: test-rapid-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2600,7 +6116,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - replica_set @@ -2609,7 +6125,7 @@ tasks: - async - name: test-rapid-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2620,7 +6136,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - replica_set @@ -2629,7 +6144,7 @@ tasks: - sync_async - name: test-rapid-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2640,7 +6155,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - replica_set @@ -2649,7 +6164,7 @@ tasks: - sync - name: test-rapid-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2660,7 +6175,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - replica_set @@ -2669,7 +6184,7 @@ tasks: - async - name: test-rapid-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2680,7 +6195,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - replica_set @@ -2689,7 +6203,7 @@ tasks: - sync_async - name: test-rapid-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2700,7 +6214,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - replica_set @@ -2709,7 +6223,7 @@ tasks: - sync - name: test-rapid-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2720,7 +6234,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - replica_set @@ -2729,7 +6243,7 @@ tasks: - async - name: test-rapid-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: replica_set @@ -2740,7 +6254,6 @@ tasks: 
AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - replica_set @@ -2749,7 +6262,7 @@ tasks: - sync_async - name: test-latest-replica_set-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2760,7 +6273,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - replica_set @@ -2769,7 +6282,7 @@ tasks: - sync - name: test-latest-replica_set-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2780,7 +6293,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - replica_set @@ -2789,7 +6302,7 @@ tasks: - async - name: test-latest-replica_set-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2800,7 +6313,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - latest - replica_set @@ -2809,7 +6321,7 @@ tasks: - sync_async - name: test-latest-replica_set-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2820,7 +6332,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - replica_set @@ -2829,7 +6341,7 @@ tasks: - sync - name: test-latest-replica_set-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2840,7 +6352,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - replica_set @@ -2849,7 +6361,7 @@ tasks: - async - name: test-latest-replica_set-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2860,7 +6372,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - latest - replica_set @@ -2869,7 +6380,7 @@ tasks: - sync_async - name: test-latest-replica_set-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2880,7 +6391,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - replica_set @@ -2889,7 +6400,7 @@ tasks: - sync - name: test-latest-replica_set-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2900,7 +6411,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - replica_set @@ -2909,7 +6420,7 @@ tasks: - async - name: test-latest-replica_set-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: replica_set @@ -2920,7 +6431,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - latest - replica_set @@ -2929,7 +6439,7 @@ tasks: - sync_async - name: test-4.0-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -2940,7 +6450,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.0" - sharded_cluster @@ -2949,7 +6459,7 @@ tasks: - sync - name: 
test-4.0-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -2960,7 +6470,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.0" - sharded_cluster @@ -2969,7 +6479,7 @@ tasks: - async - name: test-4.0-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -2980,7 +6490,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.0" - sharded_cluster @@ -2989,7 +6498,7 @@ tasks: - sync_async - name: test-4.0-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3000,7 +6509,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.0" - sharded_cluster @@ -3009,7 +6518,7 @@ tasks: - sync - name: test-4.0-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3020,7 +6529,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.0" - sharded_cluster @@ -3029,7 +6538,7 @@ tasks: - async - name: test-4.0-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3040,7 +6549,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.0" - sharded_cluster @@ -3049,7 +6557,7 @@ tasks: - sync_async - name: test-4.0-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3060,7 +6568,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.0" - sharded_cluster @@ -3069,7 +6577,7 @@ tasks: - sync - name: test-4.0-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3080,7 +6588,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.0" - sharded_cluster @@ -3089,7 +6597,7 @@ tasks: - async - name: test-4.0-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.0" TOPOLOGY: sharded_cluster @@ -3100,16 +6608,192 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "4.0" - sharded_cluster - noauth - nossl - sync_async + - name: test-4.2-sharded_cluster-auth-ssl-sync + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync + TEST_NAME: default_sync + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - sync + - name: test-4.2-sharded_cluster-auth-ssl-async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: async + TEST_NAME: default_async + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - async + - name: test-4.2-sharded_cluster-auth-ssl-sync_async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: 
run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync_async + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - sync_async + - name: test-4.2-sharded_cluster-noauth-ssl-sync + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync + TEST_NAME: default_sync + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - sync + - name: test-4.2-sharded_cluster-noauth-ssl-async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: async + TEST_NAME: default_async + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - async + - name: test-4.2-sharded_cluster-noauth-ssl-sync_async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync_async + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - sync_async + - name: test-4.2-sharded_cluster-noauth-nossl-sync + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync + TEST_NAME: default_sync + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - sync + - name: test-4.2-sharded_cluster-noauth-nossl-async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: async + TEST_NAME: default_async + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - async + - name: test-4.2-sharded_cluster-noauth-nossl-sync_async + commands: + - func: run server + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync_async + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - sync_async - name: test-4.4-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3120,7 +6804,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - sharded_cluster @@ -3129,7 +6813,7 @@ tasks: - sync - name: test-4.4-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3140,7 +6824,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - sharded_cluster @@ -3149,7 +6833,7 @@ tasks: - async - name: test-4.4-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3160,7 +6844,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - sharded_cluster @@ -3169,7 +6852,7 @@ tasks: - sync_async - name: test-4.4-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3180,7 +6863,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - sharded_cluster @@ -3189,7 +6872,7 @@ tasks: - sync - name: test-4.4-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - 
func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3200,7 +6883,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - sharded_cluster @@ -3209,7 +6892,7 @@ tasks: - async - name: test-4.4-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3220,7 +6903,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - sharded_cluster @@ -3229,7 +6911,7 @@ tasks: - sync_async - name: test-4.4-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3240,7 +6922,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "4.4" - sharded_cluster @@ -3249,7 +6931,7 @@ tasks: - sync - name: test-4.4-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3260,7 +6942,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "4.4" - sharded_cluster @@ -3269,7 +6951,7 @@ tasks: - async - name: test-4.4-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "4.4" TOPOLOGY: sharded_cluster @@ -3280,7 +6962,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "4.4" - sharded_cluster @@ -3289,7 +6970,7 @@ tasks: - sync_async - name: test-5.0-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3300,7 +6981,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - sharded_cluster @@ -3309,7 +6990,7 @@ tasks: - sync - name: test-5.0-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3320,7 +7001,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - sharded_cluster @@ -3329,7 +7010,7 @@ tasks: - async - name: test-5.0-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3340,7 +7021,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - sharded_cluster @@ -3349,7 +7029,7 @@ tasks: - sync_async - name: test-5.0-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3360,7 +7040,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - sharded_cluster @@ -3369,7 +7049,7 @@ tasks: - sync - name: test-5.0-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3380,7 +7060,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - sharded_cluster @@ -3389,7 +7069,7 @@ tasks: - async - name: test-5.0-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3400,7 +7080,6 @@ 
tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - sharded_cluster @@ -3409,7 +7088,7 @@ tasks: - sync_async - name: test-5.0-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3420,7 +7099,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "5.0" - sharded_cluster @@ -3429,7 +7108,7 @@ tasks: - sync - name: test-5.0-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3440,7 +7119,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "5.0" - sharded_cluster @@ -3449,7 +7128,7 @@ tasks: - async - name: test-5.0-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "5.0" TOPOLOGY: sharded_cluster @@ -3460,7 +7139,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "5.0" - sharded_cluster @@ -3469,7 +7147,7 @@ tasks: - sync_async - name: test-6.0-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3480,7 +7158,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - sharded_cluster @@ -3489,7 +7167,7 @@ tasks: - sync - name: test-6.0-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3500,7 +7178,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - sharded_cluster @@ -3509,7 +7187,7 @@ tasks: - async - name: test-6.0-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3520,7 +7198,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - sharded_cluster @@ -3529,7 +7206,7 @@ tasks: - sync_async - name: test-6.0-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3540,7 +7217,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - sharded_cluster @@ -3549,7 +7226,7 @@ tasks: - sync - name: test-6.0-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3560,7 +7237,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - sharded_cluster @@ -3569,7 +7246,7 @@ tasks: - async - name: test-6.0-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3580,7 +7257,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - sharded_cluster @@ -3589,7 +7265,7 @@ tasks: - sync_async - name: test-6.0-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3600,7 +7276,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "6.0" - 
sharded_cluster @@ -3609,7 +7285,7 @@ tasks: - sync - name: test-6.0-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3620,7 +7296,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "6.0" - sharded_cluster @@ -3629,7 +7305,7 @@ tasks: - async - name: test-6.0-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "6.0" TOPOLOGY: sharded_cluster @@ -3640,7 +7316,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "6.0" - sharded_cluster @@ -3649,7 +7324,7 @@ tasks: - sync_async - name: test-7.0-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3660,7 +7335,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - sharded_cluster @@ -3669,7 +7344,7 @@ tasks: - sync - name: test-7.0-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3680,7 +7355,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - sharded_cluster @@ -3689,7 +7364,7 @@ tasks: - async - name: test-7.0-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3700,7 +7375,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - sharded_cluster @@ -3709,7 +7383,7 @@ tasks: - sync_async - name: test-7.0-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3720,7 +7394,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - sharded_cluster @@ -3729,7 +7403,7 @@ tasks: - sync - name: test-7.0-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3740,7 +7414,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - sharded_cluster @@ -3749,7 +7423,7 @@ tasks: - async - name: test-7.0-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3760,7 +7434,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - sharded_cluster @@ -3769,7 +7442,7 @@ tasks: - sync_async - name: test-7.0-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3780,7 +7453,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "7.0" - sharded_cluster @@ -3789,7 +7462,7 @@ tasks: - sync - name: test-7.0-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3800,7 +7473,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "7.0" - sharded_cluster @@ -3809,7 +7482,7 @@ tasks: - async - name: 
test-7.0-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "7.0" TOPOLOGY: sharded_cluster @@ -3820,7 +7493,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "7.0" - sharded_cluster @@ -3829,7 +7501,7 @@ tasks: - sync_async - name: test-8.0-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3840,7 +7512,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - sharded_cluster @@ -3849,7 +7521,7 @@ tasks: - sync - name: test-8.0-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3860,7 +7532,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - sharded_cluster @@ -3869,7 +7541,7 @@ tasks: - async - name: test-8.0-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3880,7 +7552,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - sharded_cluster @@ -3889,7 +7560,7 @@ tasks: - sync_async - name: test-8.0-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3900,7 +7571,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - sharded_cluster @@ -3909,7 +7580,7 @@ tasks: - sync - name: test-8.0-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3920,7 +7591,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - sharded_cluster @@ -3929,7 +7600,7 @@ tasks: - async - name: test-8.0-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3940,7 +7611,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - sharded_cluster @@ -3949,7 +7619,7 @@ tasks: - sync_async - name: test-8.0-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3960,7 +7630,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - "8.0" - sharded_cluster @@ -3969,7 +7639,7 @@ tasks: - sync - name: test-8.0-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -3980,7 +7650,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - "8.0" - sharded_cluster @@ -3989,7 +7659,7 @@ tasks: - async - name: test-8.0-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: "8.0" TOPOLOGY: sharded_cluster @@ -4000,7 +7670,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - "8.0" - sharded_cluster @@ -4009,7 +7678,7 @@ tasks: - sync_async - name: test-rapid-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server 
vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4020,7 +7689,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - sharded_cluster @@ -4029,7 +7698,7 @@ tasks: - sync - name: test-rapid-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4040,7 +7709,7 @@ tasks: AUTH: auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - sharded_cluster @@ -4049,7 +7718,7 @@ tasks: - async - name: test-rapid-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4060,7 +7729,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - sharded_cluster @@ -4069,7 +7737,7 @@ tasks: - sync_async - name: test-rapid-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4080,7 +7748,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - sharded_cluster @@ -4089,7 +7757,7 @@ tasks: - sync - name: test-rapid-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4100,7 +7768,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - sharded_cluster @@ -4109,7 +7777,7 @@ tasks: - async - name: test-rapid-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4120,7 +7788,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - sharded_cluster @@ -4129,7 +7796,7 @@ tasks: - sync_async - name: test-rapid-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4140,7 +7807,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - rapid - sharded_cluster @@ -4149,7 +7816,7 @@ tasks: - sync - name: test-rapid-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4160,7 +7827,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - rapid - sharded_cluster @@ -4169,7 +7836,7 @@ tasks: - async - name: test-rapid-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: rapid TOPOLOGY: sharded_cluster @@ -4180,7 +7847,6 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - rapid - sharded_cluster @@ -4189,7 +7855,7 @@ tasks: - sync_async - name: test-latest-sharded_cluster-auth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4200,7 +7866,7 @@ tasks: AUTH: auth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - sharded_cluster @@ -4209,7 +7875,7 @@ tasks: - sync - name: test-latest-sharded_cluster-auth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4220,7 +7886,7 @@ tasks: AUTH: 
auth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - sharded_cluster @@ -4229,7 +7895,7 @@ tasks: - async - name: test-latest-sharded_cluster-auth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4240,7 +7906,6 @@ tasks: AUTH: auth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - latest - sharded_cluster @@ -4249,7 +7914,7 @@ tasks: - sync_async - name: test-latest-sharded_cluster-noauth-ssl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4260,7 +7925,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - sharded_cluster @@ -4269,7 +7934,7 @@ tasks: - sync - name: test-latest-sharded_cluster-noauth-ssl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4280,7 +7945,7 @@ tasks: AUTH: noauth SSL: ssl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - sharded_cluster @@ -4289,7 +7954,7 @@ tasks: - async - name: test-latest-sharded_cluster-noauth-ssl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4300,7 +7965,6 @@ tasks: AUTH: noauth SSL: ssl SYNC: sync_async - TEST_SUITES: "" tags: - latest - sharded_cluster @@ -4309,7 +7973,7 @@ tasks: - sync_async - name: test-latest-sharded_cluster-noauth-nossl-sync commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4320,7 +7984,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync - TEST_SUITES: default + TEST_NAME: default_sync tags: - latest - sharded_cluster @@ -4329,7 +7993,7 @@ tasks: - sync - name: test-latest-sharded_cluster-noauth-nossl-async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4340,7 +8004,7 @@ tasks: AUTH: noauth SSL: nossl SYNC: async - TEST_SUITES: default_async + TEST_NAME: default_async tags: - latest - sharded_cluster @@ -4349,7 +8013,7 @@ tasks: - async - name: test-latest-sharded_cluster-noauth-nossl-sync_async commands: - - func: bootstrap mongo-orchestration + - func: run server vars: VERSION: latest TOPOLOGY: sharded_cluster @@ -4360,10 +8024,19 @@ tasks: AUTH: noauth SSL: nossl SYNC: sync_async - TEST_SUITES: "" tags: - latest - sharded_cluster - noauth - nossl - sync_async + + # Serverless tests + - name: test-serverless + commands: + - func: run tests + vars: + TEST_NAME: serverless + AUTH: auth + SSL: ssl + tags: [serverless] diff --git a/.evergreen/generated_configs/variants.yml b/.evergreen/generated_configs/variants.yml index 79c9b22c93..b4ff40eccb 100644 --- a/.evergreen/generated_configs/variants.yml +++ b/.evergreen/generated_configs/variants.yml @@ -45,11 +45,22 @@ buildvariants: batchtime: 10080 expansions: NO_EXT: "1" + - name: other-hosts-amazon2023 + tasks: + - name: .latest !.sync_async .sharded_cluster .auth .ssl + - name: .latest !.sync_async .replica_set .noauth .ssl + - name: .latest !.sync_async .standalone .noauth .nossl + display_name: Other hosts Amazon2023 + run_on: + - amazon2023-arm64-latest-large-m8g + batchtime: 10080 + expansions: + NO_EXT: "1" # Atlas connect tests - name: atlas-connect-rhel8-python3.9 tasks: - - name: atlas-connect + - name: .atlas_connect display_name: Atlas connect RHEL8 
Python3.9 run_on: - rhel87-small @@ -57,7 +68,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 - name: atlas-connect-rhel8-python3.13 tasks: - - name: atlas-connect + - name: .atlas_connect display_name: Atlas connect RHEL8 Python3.13 run_on: - rhel87-small @@ -65,55 +76,27 @@ buildvariants: PYTHON_BINARY: /opt/python/3.13/bin/python3 # Atlas data lake tests - - name: atlas-data-lake-ubuntu-22-python3.9-auth-no-c + - name: atlas-data-lake-ubuntu-22-python3.9 tasks: - - name: atlas-data-lake-tests - display_name: Atlas Data Lake Ubuntu-22 Python3.9 Auth No C + - name: .atlas_data_lake + display_name: Atlas Data Lake Ubuntu-22 Python3.9 run_on: - ubuntu2204-small expansions: - AUTH: auth - NO_EXT: "1" PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: atlas-data-lake-ubuntu-22-python3.9-auth + - name: atlas-data-lake-ubuntu-22-python3.13 tasks: - - name: atlas-data-lake-tests - display_name: Atlas Data Lake Ubuntu-22 Python3.9 Auth + - name: .atlas_data_lake + display_name: Atlas Data Lake Ubuntu-22 Python3.13 run_on: - ubuntu2204-small expansions: - AUTH: auth - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: atlas-data-lake-ubuntu-22-python3.13-auth-no-c - tasks: - - name: atlas-data-lake-tests - display_name: Atlas Data Lake Ubuntu-22 Python3.13 Auth No C - run_on: - - ubuntu2204-small - expansions: - AUTH: auth - NO_EXT: "1" - PYTHON_BINARY: /opt/python/3.13/bin/python3 - - name: atlas-data-lake-ubuntu-22-python3.13-auth - tasks: - - name: atlas-data-lake-tests - display_name: Atlas Data Lake Ubuntu-22 Python3.13 Auth - run_on: - - ubuntu2204-small - expansions: - AUTH: auth PYTHON_BINARY: /opt/python/3.13/bin/python3 # Aws auth tests - name: auth-aws-ubuntu-20-python3.9 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws display_name: Auth AWS Ubuntu-20 Python3.9 run_on: - ubuntu2004-small @@ -121,13 +104,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 - name: auth-aws-ubuntu-20-python3.13 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws display_name: Auth AWS Ubuntu-20 Python3.13 run_on: - ubuntu2004-small @@ -135,154 +112,86 @@ buildvariants: PYTHON_BINARY: /opt/python/3.13/bin/python3 - name: auth-aws-win64-python3.9 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws !.auth-aws-ecs display_name: Auth AWS Win64 Python3.9 run_on: - windows-64-vsMulti-small expansions: - skip_ECS_auth_test: "true" PYTHON_BINARY: C:/python/Python39/python.exe - name: auth-aws-win64-python3.13 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws !.auth-aws-ecs display_name: Auth AWS Win64 Python3.13 run_on: - windows-64-vsMulti-small expansions: - skip_ECS_auth_test: "true" PYTHON_BINARY: C:/python/Python313/python.exe - name: auth-aws-macos-python3.9 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - 
name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws !.auth-aws-web-identity !.auth-aws-ecs !.auth-aws-ec2 display_name: Auth AWS macOS Python3.9 run_on: - macos-14 expansions: - skip_ECS_auth_test: "true" - skip_EC2_auth_test: "true" - skip_web_identity_auth_test: "true" PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - name: auth-aws-macos-python3.13 tasks: - - name: aws-auth-test-4.4 - - name: aws-auth-test-5.0 - - name: aws-auth-test-6.0 - - name: aws-auth-test-7.0 - - name: aws-auth-test-8.0 - - name: aws-auth-test-rapid - - name: aws-auth-test-latest + - name: .auth-aws !.auth-aws-web-identity !.auth-aws-ecs !.auth-aws-ec2 display_name: Auth AWS macOS Python3.13 run_on: - macos-14 expansions: - skip_ECS_auth_test: "true" - skip_EC2_auth_test: "true" - skip_web_identity_auth_test: "true" PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 - # Compression tests - - name: compression-snappy-rhel8-python3.9-no-c - tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression snappy RHEL8 Python3.9 No C - run_on: - - rhel87-small - expansions: - COMPRESSORS: snappy - NO_EXT: "1" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: compression-snappy-rhel8-python3.10 + # Aws lambda tests + - name: faas-lambda tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression snappy RHEL8 Python3.10 + - name: .aws_lambda + display_name: FaaS Lambda run_on: - rhel87-small - expansions: - COMPRESSORS: snappy - PYTHON_BINARY: /opt/python/3.10/bin/python3 - - name: compression-zlib-rhel8-python3.11-no-c - tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression zlib RHEL8 Python3.11 No C - run_on: - - rhel87-small - expansions: - COMPRESSORS: zlib - NO_EXT: "1" - PYTHON_BINARY: /opt/python/3.11/bin/python3 - - name: compression-zlib-rhel8-python3.12 - tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression zlib RHEL8 Python3.12 - run_on: - - rhel87-small - expansions: - COMPRESSORS: zlib - PYTHON_BINARY: /opt/python/3.12/bin/python3 - - name: compression-zstd-rhel8-python3.13-no-c + + # Backport pr tests + - name: backport-pr tasks: - - name: .standalone .noauth .nossl .sync_async !.4.0 - display_name: Compression zstd RHEL8 Python3.13 No C + - name: backport-pr + display_name: Backport PR run_on: - rhel87-small - expansions: - COMPRESSORS: zstd - NO_EXT: "1" - PYTHON_BINARY: /opt/python/3.13/bin/python3 - - name: compression-zstd-rhel8-python3.9 + + # Compression tests + - name: compression-snappy-rhel8 tasks: - - name: .standalone .noauth .nossl .sync_async !.4.0 - display_name: Compression zstd RHEL8 Python3.9 + - name: .compression + display_name: Compression snappy RHEL8 run_on: - rhel87-small expansions: - COMPRESSORS: zstd - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: compression-snappy-rhel8-pypy3.10 + COMPRESSOR: snappy + - name: compression-zlib-rhel8 tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression snappy RHEL8 PyPy3.10 + - name: .compression + display_name: Compression zlib RHEL8 run_on: - rhel87-small expansions: - COMPRESSORS: snappy - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 - - name: compression-zlib-rhel8-pypy3.10 + COMPRESSOR: zlib + - name: compression-zstd-rhel8 tasks: - - name: .standalone .noauth .nossl .sync_async - display_name: Compression zlib RHEL8 PyPy3.10 + - name: .compression !.4.0 + display_name: Compression zstd 
RHEL8 run_on: - rhel87-small expansions: - COMPRESSORS: zlib - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 - - name: compression-zstd-rhel8-pypy3.10 + COMPRESSOR: zstd + + # Coverage report tests + - name: coverage-report tasks: - - name: .standalone .noauth .nossl .sync_async !.4.0 - display_name: Compression zstd RHEL8 PyPy3.10 + - name: coverage-report + display_name: Coverage Report run_on: - rhel87-small - expansions: - COMPRESSORS: zstd - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 # Disable test commands tests - name: disable-test-commands-rhel8-python3.9 @@ -300,7 +209,7 @@ buildvariants: # Doctests tests - name: doctests-rhel8-python3.9 tasks: - - name: doctests + - name: .doctests display_name: Doctests RHEL8 Python3.9 run_on: - rhel87-small @@ -318,7 +227,7 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /opt/python/3.9/bin/python3 tags: [encryption_tag] - name: encryption-rhel8-python3.13 @@ -331,7 +240,7 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /opt/python/3.13/bin/python3 tags: [encryption_tag] - name: encryption-rhel8-pypy3.10 @@ -344,7 +253,7 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 tags: [encryption_tag] - name: encryption-crypt_shared-rhel8-python3.9 @@ -357,8 +266,8 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /opt/python/3.9/bin/python3 tags: [encryption_tag] - name: encryption-crypt_shared-rhel8-python3.13 @@ -371,8 +280,8 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /opt/python/3.13/bin/python3 tags: [encryption_tag] - name: encryption-crypt_shared-rhel8-pypy3.10 @@ -385,50 +294,50 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 tags: [encryption_tag] - name: encryption-pyopenssl-rhel8-python3.9 tasks: - - name: .sharded_cluster .auth .ssl .sync_async - - name: .replica_set .noauth .ssl .sync_async - - name: .standalone .noauth .nossl .sync_async + - name: .sharded_cluster .auth .ssl .sync + - name: .replica_set .noauth .ssl .sync + - name: .standalone .noauth .nossl .sync display_name: Encryption PyOpenSSL RHEL8 Python3.9 run_on: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_encryption_pyopenssl: "true" + TEST_NAME: encryption + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/3.9/bin/python3 tags: [encryption_tag] - name: encryption-pyopenssl-rhel8-python3.13 tasks: - - name: .sharded_cluster .auth .ssl .sync_async - - name: .replica_set .noauth .ssl .sync_async - - name: .standalone .noauth .nossl .sync_async + - name: .sharded_cluster .auth .ssl .sync + - name: .replica_set .noauth .ssl .sync + - name: .standalone .noauth .nossl .sync display_name: Encryption PyOpenSSL RHEL8 Python3.13 run_on: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_encryption_pyopenssl: "true" + TEST_NAME: encryption + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/3.13/bin/python3 tags: 
[encryption_tag] - name: encryption-pyopenssl-rhel8-pypy3.10 tasks: - - name: .sharded_cluster .auth .ssl .sync_async - - name: .replica_set .noauth .ssl .sync_async - - name: .standalone .noauth .nossl .sync_async + - name: .sharded_cluster .auth .ssl .sync + - name: .replica_set .noauth .ssl .sync + - name: .standalone .noauth .nossl .sync display_name: Encryption PyOpenSSL RHEL8 PyPy3.10 run_on: - rhel87-small batchtime: 10080 expansions: - test_encryption: "true" - test_encryption_pyopenssl: "true" + TEST_NAME: encryption + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 tags: [encryption_tag] - name: encryption-rhel8-python3.10 @@ -438,7 +347,7 @@ buildvariants: run_on: - rhel87-small expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /opt/python/3.10/bin/python3 - name: encryption-crypt_shared-rhel8-python3.11 tasks: @@ -447,18 +356,17 @@ buildvariants: run_on: - rhel87-small expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /opt/python/3.11/bin/python3 - - name: encryption-pyopenssl-rhel8-python3.12 + - name: encryption-rhel8-python3.12 tasks: - name: .standalone .noauth .nossl .sync_async - display_name: Encryption PyOpenSSL RHEL8 Python3.12 + display_name: Encryption RHEL8 Python3.12 run_on: - rhel87-small expansions: - test_encryption: "true" - test_encryption_pyopenssl: "true" + TEST_NAME: encryption PYTHON_BINARY: /opt/python/3.12/bin/python3 - name: encryption-macos-python3.9 tasks: @@ -468,7 +376,7 @@ buildvariants: - macos-14 batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 tags: [encryption_tag] - name: encryption-macos-python3.13 @@ -479,7 +387,7 @@ buildvariants: - macos-14 batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 tags: [encryption_tag] - name: encryption-crypt_shared-macos-python3.9 @@ -490,8 +398,8 @@ buildvariants: - macos-14 batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 tags: [encryption_tag] - name: encryption-crypt_shared-macos-python3.13 @@ -502,8 +410,8 @@ buildvariants: - macos-14 batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 tags: [encryption_tag] - name: encryption-win64-python3.9 @@ -514,7 +422,7 @@ buildvariants: - windows-64-vsMulti-small batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: C:/python/Python39/python.exe tags: [encryption_tag] - name: encryption-win64-python3.13 @@ -525,7 +433,7 @@ buildvariants: - windows-64-vsMulti-small batchtime: 10080 expansions: - test_encryption: "true" + TEST_NAME: encryption PYTHON_BINARY: C:/python/Python313/python.exe tags: [encryption_tag] - name: encryption-crypt_shared-win64-python3.9 @@ -536,8 +444,8 @@ buildvariants: - windows-64-vsMulti-small batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: C:/python/Python39/python.exe tags: [encryption_tag] - name: encryption-crypt_shared-win64-python3.13 @@ 
-548,66 +456,30 @@ buildvariants: - windows-64-vsMulti-small batchtime: 10080 expansions: - test_encryption: "true" - test_crypt_shared: "true" + TEST_NAME: encryption + TEST_CRYPT_SHARED: "true" PYTHON_BINARY: C:/python/Python313/python.exe tags: [encryption_tag] # Enterprise auth tests - - name: auth-enterprise-macos-python3.9-auth + - name: auth-enterprise-macos tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise macOS Python3.9 Auth + - name: .enterprise_auth !.pypy + display_name: Auth Enterprise macOS run_on: - macos-14 - expansions: - AUTH: auth - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - - name: auth-enterprise-rhel8-python3.10-auth - tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise RHEL8 Python3.10 Auth - run_on: - - rhel87-small - expansions: - AUTH: auth - PYTHON_BINARY: /opt/python/3.10/bin/python3 - - name: auth-enterprise-rhel8-python3.11-auth - tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise RHEL8 Python3.11 Auth - run_on: - - rhel87-small - expansions: - AUTH: auth - PYTHON_BINARY: /opt/python/3.11/bin/python3 - - name: auth-enterprise-rhel8-python3.12-auth - tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise RHEL8 Python3.12 Auth - run_on: - - rhel87-small - expansions: - AUTH: auth - PYTHON_BINARY: /opt/python/3.12/bin/python3 - - name: auth-enterprise-win64-python3.13-auth + - name: auth-enterprise-win64 tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise Win64 Python3.13 Auth + - name: .enterprise_auth !.pypy + display_name: Auth Enterprise Win64 run_on: - windows-64-vsMulti-small - expansions: - AUTH: auth - PYTHON_BINARY: C:/python/Python313/python.exe - - name: auth-enterprise-rhel8-pypy3.10-auth + - name: auth-enterprise-rhel8 tasks: - - name: test-enterprise-auth - display_name: Auth Enterprise RHEL8 PyPy3.10 Auth + - name: .enterprise_auth + display_name: Auth Enterprise RHEL8 run_on: - rhel87-small - expansions: - AUTH: auth - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 # Free threaded tests - name: free-threaded-rhel8-python3.13t @@ -681,62 +553,40 @@ buildvariants: SSL: ssl PYTHON_BINARY: /opt/python/3.13/bin/python3 - # Load balancer tests - - name: load-balancer-rhel8-v6.0-python3.9 - tasks: - - name: .load-balancer - display_name: Load Balancer RHEL8 v6.0 Python3.9 - run_on: - - rhel87-small - batchtime: 10080 - expansions: - VERSION: "6.0" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: load-balancer-rhel8-v7.0-python3.9 - tasks: - - name: .load-balancer - display_name: Load Balancer RHEL8 v7.0 Python3.9 - run_on: - - rhel87-small - batchtime: 10080 - expansions: - VERSION: "7.0" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: load-balancer-rhel8-v8.0-python3.9 - tasks: - - name: .load-balancer - display_name: Load Balancer RHEL8 v8.0 Python3.9 - run_on: - - rhel87-small - batchtime: 10080 - expansions: - VERSION: "8.0" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: load-balancer-rhel8-rapid-python3.9 + # Import time tests + - name: import-time tasks: - - name: .load-balancer - display_name: Load Balancer RHEL8 rapid Python3.9 + - name: check-import-time + display_name: Import Time run_on: - rhel87-small - batchtime: 10080 - expansions: - VERSION: rapid - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: load-balancer-rhel8-latest-python3.9 + + # Kms tests + - name: kms + tasks: + - name: test-gcpkms + batchtime: 10080 + - name: test-gcpkms-fail + - name: test-azurekms + batchtime: 10080 + - 
name: test-azurekms-fail + display_name: KMS + run_on: + - debian11-small + + # Load balancer tests + - name: load-balancer tasks: - name: .load-balancer - display_name: Load Balancer RHEL8 latest Python3.9 + display_name: Load Balancer run_on: - rhel87-small batchtime: 10080 - expansions: - VERSION: latest - PYTHON_BINARY: /opt/python/3.9/bin/python3 # Mockupdb tests - name: mockupdb-rhel8-python3.9 tasks: - - name: mockupdb + - name: .mockupdb display_name: MockupDB RHEL8 Python3.9 run_on: - rhel87-small @@ -746,10 +596,7 @@ buildvariants: # Mod wsgi tests - name: mod_wsgi-ubuntu-22-python3.9 tasks: - - name: mod-wsgi-standalone - - name: mod-wsgi-replica-set - - name: mod-wsgi-embedded-mode-standalone - - name: mod-wsgi-embedded-mode-replica-set + - name: .mod_wsgi display_name: mod_wsgi Ubuntu-22 Python3.9 run_on: - ubuntu2204-small @@ -758,10 +605,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 - name: mod_wsgi-ubuntu-22-python3.13 tasks: - - name: mod-wsgi-standalone - - name: mod-wsgi-replica-set - - name: mod-wsgi-embedded-mode-standalone - - name: mod-wsgi-embedded-mode-replica-set + - name: .mod_wsgi display_name: mod_wsgi Ubuntu-22 Python3.13 run_on: - ubuntu2204-small @@ -772,7 +616,7 @@ buildvariants: # No c ext tests - name: no-c-ext-rhel8-python3.9 tasks: - - name: .standalone .noauth .nossl .sync_async + - name: .standalone .noauth .nossl !.sync_async display_name: No C Ext RHEL8 Python3.9 run_on: - rhel87-small @@ -781,7 +625,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 - name: no-c-ext-rhel8-python3.10 tasks: - - name: .replica_set .noauth .nossl .sync_async + - name: .replica_set .noauth .nossl !.sync_async display_name: No C Ext RHEL8 Python3.10 run_on: - rhel87-small @@ -790,7 +634,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.10/bin/python3 - name: no-c-ext-rhel8-python3.11 tasks: - - name: .sharded_cluster .noauth .nossl .sync_async + - name: .sharded_cluster .noauth .nossl !.sync_async display_name: No C Ext RHEL8 Python3.11 run_on: - rhel87-small @@ -799,7 +643,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.11/bin/python3 - name: no-c-ext-rhel8-python3.12 tasks: - - name: .standalone .noauth .nossl .sync_async + - name: .standalone .noauth .nossl !.sync_async display_name: No C Ext RHEL8 Python3.12 run_on: - rhel87-small @@ -808,7 +652,7 @@ buildvariants: PYTHON_BINARY: /opt/python/3.12/bin/python3 - name: no-c-ext-rhel8-python3.13 tasks: - - name: .replica_set .noauth .nossl .sync_async + - name: .replica_set .noauth .nossl !.sync_async display_name: No C Ext RHEL8 Python3.13 run_on: - rhel87-small @@ -816,188 +660,83 @@ buildvariants: NO_EXT: "1" PYTHON_BINARY: /opt/python/3.13/bin/python3 - # Ocsp tests - - name: ocsp-rhel8-v4.4-python3.9 - tasks: - - name: .ocsp - display_name: OCSP RHEL8 v4.4 Python3.9 - run_on: - - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "4.4" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: ocsp-rhel8-v5.0-python3.10 + # No server tests + - name: no-server tasks: - - name: .ocsp - display_name: OCSP RHEL8 v5.0 Python3.10 - run_on: - - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "5.0" - PYTHON_BINARY: /opt/python/3.10/bin/python3 - - name: ocsp-rhel8-v6.0-python3.11 - tasks: - - name: .ocsp - display_name: OCSP RHEL8 v6.0 Python3.11 + - name: .no-server + display_name: No server run_on: - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - 
VERSION: "6.0" - PYTHON_BINARY: /opt/python/3.11/bin/python3 - - name: ocsp-rhel8-v7.0-python3.12 - tasks: - - name: .ocsp - display_name: OCSP RHEL8 v7.0 Python3.12 - run_on: - - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "7.0" - PYTHON_BINARY: /opt/python/3.12/bin/python3 - - name: ocsp-rhel8-v8.0-python3.13 - tasks: - - name: .ocsp - display_name: OCSP RHEL8 v8.0 Python3.13 - run_on: - - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "8.0" - PYTHON_BINARY: /opt/python/3.13/bin/python3 - - name: ocsp-rhel8-rapid-pypy3.10 - tasks: - - name: .ocsp - display_name: OCSP RHEL8 rapid PyPy3.10 - run_on: - - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: rapid - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 - - name: ocsp-rhel8-latest-python3.9 + + # Ocsp tests + - name: ocsp-rhel8 tasks: - name: .ocsp - display_name: OCSP RHEL8 latest Python3.9 + display_name: OCSP RHEL8 run_on: - rhel87-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: latest - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: ocsp-win64-v4.4-python3.9 - tasks: - - name: .ocsp-rsa !.ocsp-staple - display_name: OCSP Win64 v4.4 Python3.9 - run_on: - - windows-64-vsMulti-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "4.4" - PYTHON_BINARY: C:/python/Python39/python.exe - - name: ocsp-win64-v8.0-python3.13 + batchtime: 10080 + - name: ocsp-win64 tasks: - - name: .ocsp-rsa !.ocsp-staple - display_name: OCSP Win64 v8.0 Python3.13 + - name: .ocsp-rsa !.ocsp-staple .latest + - name: .ocsp-rsa !.ocsp-staple .4.4 + display_name: OCSP Win64 run_on: - windows-64-vsMulti-small - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "8.0" - PYTHON_BINARY: C:/python/Python313/python.exe - - name: ocsp-macos-v4.4-python3.9 - tasks: - - name: .ocsp-rsa !.ocsp-staple - display_name: OCSP macOS v4.4 Python3.9 - run_on: - - macos-14 - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "4.4" - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - - name: ocsp-macos-v8.0-python3.13 + batchtime: 10080 + - name: ocsp-macos tasks: - - name: .ocsp-rsa !.ocsp-staple - display_name: OCSP macOS v8.0 Python3.13 + - name: .ocsp-rsa !.ocsp-staple .latest + - name: .ocsp-rsa !.ocsp-staple .4.4 + display_name: OCSP macOS run_on: - macos-14 - batchtime: 20160 - expansions: - AUTH: noauth - SSL: ssl - TOPOLOGY: server - VERSION: "8.0" - PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 + batchtime: 10080 # Oidc auth tests - name: auth-oidc-ubuntu-22 tasks: - - name: testoidc_task_group - - name: testazureoidc_task_group - - name: testgcpoidc_task_group - - name: testk8soidc_task_group + - name: .auth_oidc display_name: Auth OIDC Ubuntu-22 run_on: - ubuntu2204-small - batchtime: 20160 + batchtime: 10080 - name: auth-oidc-macos tasks: - - name: testoidc_task_group + - name: .auth_oidc !.auth_oidc_remote display_name: Auth OIDC macOS run_on: - macos-14 - batchtime: 20160 + batchtime: 10080 - name: auth-oidc-win64 tasks: - - name: testoidc_task_group + - name: .auth_oidc !.auth_oidc_remote display_name: Auth OIDC Win64 run_on: - windows-64-vsMulti-small - batchtime: 20160 + batchtime: 10080 + + # Perf tests + - name: performance-benchmarks + tasks: + - name: .perf + display_name: 
Performance Benchmarks + run_on: + - rhel90-dbx-perf-large + batchtime: 10080 # Pyopenssl tests - name: pyopenssl-macos-python3.9 tasks: - - name: .replica_set .noauth .nossl .sync_async - - name: .7.0 .noauth .nossl .sync_async + - name: .replica_set .noauth .nossl .sync + - name: .7.0 .noauth .nossl .sync display_name: PyOpenSSL macOS Python3.9 run_on: - macos-14 batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 - name: pyopenssl-rhel8-python3.10 tasks: @@ -1008,29 +747,32 @@ buildvariants: - rhel87-small batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/3.10/bin/python3 - name: pyopenssl-rhel8-python3.11 tasks: - - name: .replica_set .auth .ssl .sync_async - - name: .7.0 .auth .ssl .sync_async + - name: .replica_set .auth .ssl .sync + - name: .7.0 .auth .ssl .sync display_name: PyOpenSSL RHEL8 Python3.11 run_on: - rhel87-small batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/3.11/bin/python3 - name: pyopenssl-rhel8-python3.12 tasks: - - name: .replica_set .auth .ssl .sync_async - - name: .7.0 .auth .ssl .sync_async + - name: .replica_set .auth .ssl .sync + - name: .7.0 .auth .ssl .sync display_name: PyOpenSSL RHEL8 Python3.12 run_on: - rhel87-small batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/3.12/bin/python3 - name: pyopenssl-win64-python3.13 tasks: @@ -1041,24 +783,26 @@ buildvariants: - windows-64-vsMulti-small batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: C:/python/Python313/python.exe - name: pyopenssl-rhel8-pypy3.10 tasks: - - name: .replica_set .auth .ssl .sync_async - - name: .7.0 .auth .ssl .sync_async + - name: .replica_set .auth .ssl .sync + - name: .7.0 .auth .ssl .sync display_name: PyOpenSSL RHEL8 PyPy3.10 run_on: - rhel87-small batchtime: 10080 expansions: - test_pyopenssl: "true" + TEST_NAME: default + SUB_TEST_NAME: pyopenssl PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 # Search index tests - name: search-index-helpers-rhel8-python3.9 tasks: - - name: test_atlas_task_group_search_indexes + - name: .search_index display_name: Search Index Helpers RHEL8 Python3.9 run_on: - rhel87-small @@ -1066,6 +810,19 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 # Server tests + - name: test-rhel8-python3.9-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 Python3.9 cov No C" + run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [coverage_tag] - name: test-rhel8-python3.9-cov tasks: - name: .standalone .sync_async @@ -1078,6 +835,19 @@ buildvariants: COVERAGE: coverage PYTHON_BINARY: /opt/python/3.9/bin/python3 tags: [coverage_tag] + - name: test-rhel8-python3.13-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 Python3.13 cov No C" + run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [coverage_tag] - name: test-rhel8-python3.13-cov tasks: - name: .standalone .sync_async @@ 
-1090,6 +860,19 @@ buildvariants: COVERAGE: coverage PYTHON_BINARY: /opt/python/3.13/bin/python3 tags: [coverage_tag] + - name: test-rhel8-pypy3.10-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 PyPy3.10 cov No C" + run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [coverage_tag] - name: test-rhel8-pypy3.10-cov tasks: - name: .standalone .sync_async @@ -1243,38 +1026,32 @@ buildvariants: # Serverless tests - name: serverless-rhel8-python3.9 tasks: - - name: serverless_task_group + - name: .serverless display_name: Serverless RHEL8 Python3.9 run_on: - rhel87-small batchtime: 10080 expansions: - test_serverless: "true" - AUTH: auth - SSL: ssl PYTHON_BINARY: /opt/python/3.9/bin/python3 - name: serverless-rhel8-python3.13 tasks: - - name: serverless_task_group + - name: .serverless display_name: Serverless RHEL8 Python3.13 run_on: - rhel87-small batchtime: 10080 expansions: - test_serverless: "true" - AUTH: auth - SSL: ssl PYTHON_BINARY: /opt/python/3.13/bin/python3 # Stable api tests - name: stable-api-require-v1-rhel8-python3.9-auth tasks: - - name: .standalone .5.0 .noauth .nossl .sync_async - - name: .standalone .6.0 .noauth .nossl .sync_async - - name: .standalone .7.0 .noauth .nossl .sync_async - - name: .standalone .8.0 .noauth .nossl .sync_async - - name: .standalone .rapid .noauth .nossl .sync_async - - name: .standalone .latest .noauth .nossl .sync_async + - name: "!.replica_set .5.0 .noauth .nossl .sync_async" + - name: "!.replica_set .6.0 .noauth .nossl .sync_async" + - name: "!.replica_set .7.0 .noauth .nossl .sync_async" + - name: "!.replica_set .8.0 .noauth .nossl .sync_async" + - name: "!.replica_set .rapid .noauth .nossl .sync_async" + - name: "!.replica_set .latest .noauth .nossl .sync_async" display_name: Stable API require v1 RHEL8 Python3.9 Auth run_on: - rhel87-small @@ -1302,12 +1079,12 @@ buildvariants: tags: [versionedApi_tag] - name: stable-api-require-v1-rhel8-python3.13-auth tasks: - - name: .standalone .5.0 .noauth .nossl .sync_async - - name: .standalone .6.0 .noauth .nossl .sync_async - - name: .standalone .7.0 .noauth .nossl .sync_async - - name: .standalone .8.0 .noauth .nossl .sync_async - - name: .standalone .rapid .noauth .nossl .sync_async - - name: .standalone .latest .noauth .nossl .sync_async + - name: "!.replica_set .5.0 .noauth .nossl .sync_async" + - name: "!.replica_set .6.0 .noauth .nossl .sync_async" + - name: "!.replica_set .7.0 .noauth .nossl .sync_async" + - name: "!.replica_set .8.0 .noauth .nossl .sync_async" + - name: "!.replica_set .rapid .noauth .nossl .sync_async" + - name: "!.replica_set .latest .noauth .nossl .sync_async" display_name: Stable API require v1 RHEL8 Python3.13 Auth run_on: - rhel87-small @@ -1338,6 +1115,7 @@ buildvariants: - name: storage-inmemory-rhel8-python3.9 tasks: - name: .standalone .noauth .nossl .4.0 .sync_async + - name: .standalone .noauth .nossl .4.2 .sync_async - name: .standalone .noauth .nossl .4.4 .sync_async - name: .standalone .noauth .nossl .5.0 .sync_async - name: .standalone .noauth .nossl .6.0 .sync_async diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index dca116c2d3..1f70940aa0 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -1,6 +1,6 @@ #!/bin/bash -# exit when any command fails -set -e +# Resync test files from the specifications repo. 
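
The config.yml variants above no longer enumerate task names; they select tasks with Evergreen tag selectors such as `.load-balancer` or `.standalone .noauth .nossl !.sync_async`, where every `.tag` term must be present on a task and every `!.tag` term must be absent. A minimal Python sketch of those matching semantics (an illustration of the assumed behavior, not Evergreen's implementation):

    def matches(selector: str, task_tags: set[str]) -> bool:
        # '.tag' terms are required; '!.tag' terms are excluded.
        for term in selector.split():
            if term.startswith("!."):
                if term[2:] in task_tags:
                    return False
            elif term.startswith("."):
                if term[1:] not in task_tags:
                    return False
        return True

    # A task tagged {standalone, noauth, nossl, sync_async} is dropped by the
    # new '!.sync_async' selectors but still matched by the plain form:
    assert not matches(".standalone .noauth .nossl !.sync_async",
                       {"standalone", "noauth", "nossl", "sync_async"})
    assert matches(".standalone .noauth .nossl",
                   {"standalone", "noauth", "nossl", "sync_async"})
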
+set -eu PYMONGO=$(dirname "$(cd "$(dirname "$0")"; pwd)") SPECS=${MDB_SPECS:-~/Work/specifications} diff --git a/.evergreen/run-azurekms-fail-test.sh b/.evergreen/run-azurekms-fail-test.sh deleted file mode 100755 index d1117dcb32..0000000000 --- a/.evergreen/run-azurekms-fail-test.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -HERE=$(dirname ${BASH_SOURCE:-$0}) -. $DRIVERS_TOOLS/.evergreen/csfle/azurekms/setup-secrets.sh -export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz -SKIP_SERVERS=1 bash $HERE/setup-encryption.sh -PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 \ - KEY_NAME="${AZUREKMS_KEYNAME}" \ - KEY_VAULT_ENDPOINT="${AZUREKMS_KEYVAULTENDPOINT}" \ - SUCCESS=false TEST_FLE_AZURE_AUTO=1 \ - $HERE/just.sh test-eg -bash $HERE/teardown-encryption.sh diff --git a/.evergreen/run-azurekms-test.sh b/.evergreen/run-azurekms-test.sh deleted file mode 100755 index 28a84a52e2..0000000000 --- a/.evergreen/run-azurekms-test.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -HERE=$(dirname ${BASH_SOURCE:-$0}) -source ${DRIVERS_TOOLS}/.evergreen/csfle/azurekms/secrets-export.sh -echo "Copying files ... begin" -export AZUREKMS_RESOURCEGROUP=${AZUREKMS_RESOURCEGROUP} -export AZUREKMS_VMNAME=${AZUREKMS_VMNAME} -export AZUREKMS_PRIVATEKEYPATH=/tmp/testazurekms_privatekey -export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz -SKIP_SERVERS=1 bash $HERE/setup-encryption.sh -# Set up the remote files to test. -git add . -git commit -m "add files" || true -git archive -o /tmp/mongo-python-driver.tar HEAD -tar -rf /tmp/mongo-python-driver.tar libmongocrypt -gzip -f /tmp/mongo-python-driver.tar -# shellcheck disable=SC2088 -AZUREKMS_SRC="/tmp/mongo-python-driver.tar.gz" AZUREKMS_DST="~/" \ - $DRIVERS_TOOLS/.evergreen/csfle/azurekms/copy-file.sh -echo "Copying files ... end" -echo "Untarring file ... begin" -AZUREKMS_CMD="tar xf mongo-python-driver.tar.gz" \ - $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh -echo "Untarring file ... end" -echo "Running test ... begin" -AZUREKMS_CMD="KEY_NAME=\"$AZUREKMS_KEYNAME\" KEY_VAULT_ENDPOINT=\"$AZUREKMS_KEYVAULTENDPOINT\" SUCCESS=true TEST_FLE_AZURE_AUTO=1 bash ./.evergreen/just.sh test-eg" \ - $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh -echo "Running test ... end" -bash $HERE/teardown-encryption.sh diff --git a/.evergreen/run-deployed-lambda-aws-tests.sh b/.evergreen/run-deployed-lambda-aws-tests.sh deleted file mode 100755 index aa16d62650..0000000000 --- a/.evergreen/run-deployed-lambda-aws-tests.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail - -export PATH="/opt/python/3.9/bin:${PATH}" -python --version -pushd ./test/lambda - -. build.sh -popd -. ${DRIVERS_TOOLS}/.evergreen/aws_lambda/run-deployed-lambda-aws-tests.sh diff --git a/.evergreen/run-gcpkms-test.sh b/.evergreen/run-gcpkms-test.sh deleted file mode 100755 index 37ec2bfe56..0000000000 --- a/.evergreen/run-gcpkms-test.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -HERE=$(dirname ${BASH_SOURCE:-$0}) - -source ${DRIVERS_TOOLS}/.evergreen/csfle/gcpkms/secrets-export.sh -echo "Copying files ... 
begin" -export GCPKMS_GCLOUD=${GCPKMS_GCLOUD} -export GCPKMS_PROJECT=${GCPKMS_PROJECT} -export GCPKMS_ZONE=${GCPKMS_ZONE} -export GCPKMS_INSTANCENAME=${GCPKMS_INSTANCENAME} -export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz -SKIP_SERVERS=1 bash $HERE/setup-encryption.sh -# Set up the remote files to test. -git add . -git commit -m "add files" || true -git archive -o /tmp/mongo-python-driver.tar HEAD -tar -rf /tmp/mongo-python-driver.tar libmongocrypt -gzip -f /tmp/mongo-python-driver.tar -GCPKMS_SRC=/tmp/mongo-python-driver.tar.gz GCPKMS_DST=$GCPKMS_INSTANCENAME: $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/copy-file.sh -echo "Copying files ... end" -echo "Untarring file ... begin" -GCPKMS_CMD="tar xf mongo-python-driver.tar.gz" $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/run-command.sh -echo "Untarring file ... end" -echo "Running test ... begin" -GCPKMS_CMD="SUCCESS=true TEST_FLE_GCP_AUTO=1 ./.evergreen/just.sh test-eg" $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/run-command.sh -echo "Running test ... end" -bash $HERE/teardown-encryption.sh diff --git a/.evergreen/run-import-time-test.sh b/.evergreen/run-import-time-test.sh deleted file mode 100755 index 95e3c93d25..0000000000 --- a/.evergreen/run-import-time-test.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -ex - -set -o errexit # Exit the script with error if any of the commands fail -set -x - -. .evergreen/utils.sh - -if [ -z "${PYTHON_BINARY:-}" ]; then - PYTHON_BINARY=$(find_python3) -fi - -# Use the previous commit if this was not a PR run. -if [ "$BASE_SHA" == "$HEAD_SHA" ]; then - BASE_SHA=$(git rev-parse HEAD~1) -fi - -function get_import_time() { - local log_file - createvirtualenv "$PYTHON_BINARY" import-venv - python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]" - # Import once to cache modules - python -c "import pymongo" - log_file="pymongo-$1.log" - python -X importtime -c "import pymongo" 2> $log_file -} - -get_import_time $HEAD_SHA -git stash || true -git checkout $BASE_SHA -get_import_time $BASE_SHA -git checkout $HEAD_SHA -git stash apply || true -python tools/compare_import_time.py $HEAD_SHA $BASE_SHA diff --git a/.evergreen/run-mongodb-aws-ecs-test.sh b/.evergreen/run-mongodb-aws-ecs-test.sh index 91777be226..c55c423e49 100755 --- a/.evergreen/run-mongodb-aws-ecs-test.sh +++ b/.evergreen/run-mongodb-aws-ecs-test.sh @@ -1,7 +1,6 @@ #!/bin/bash - -# Don't trace since the URI contains a password that shouldn't show up in the logs -set -o errexit # Exit the script with error if any of the commands fail +# Script run on an ECS host to test MONGODB-AWS. 
+set -eu ############################################ # Main Program # @@ -26,9 +25,9 @@ apt-get -qq update < /dev/null > /dev/null apt-get -qq install $PYTHON_VER $PYTHON_VER-venv build-essential $PYTHON_VER-dev -y < /dev/null > /dev/null export PYTHON_BINARY=$PYTHON_VER -export TEST_AUTH_AWS=1 -export AUTH="auth" export SET_XTRACE_ON=1 cd src rm -rf .venv -bash .evergreen/just.sh test-eg +rm -f .evergreen/scripts/test-env.sh || true +bash ./.evergreen/just.sh setup-tests auth_aws ecs-remote +bash .evergreen/just.sh run-tests diff --git a/.evergreen/run-mongodb-oidc-remote-test.sh b/.evergreen/run-mongodb-oidc-remote-test.sh deleted file mode 100755 index bb90bddf07..0000000000 --- a/.evergreen/run-mongodb-oidc-remote-test.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash - -set +x # Disable debug trace -set -eu - -echo "Running MONGODB-OIDC remote tests" - -OIDC_ENV=${OIDC_ENV:-"test"} - -# Make sure DRIVERS_TOOLS is set. -if [ -z "$DRIVERS_TOOLS" ]; then - echo "Must specify DRIVERS_TOOLS" - exit 1 -fi - -# Set up the remote files to test. -git add . -git commit -m "add files" || true -export TEST_TAR_FILE=/tmp/mongo-python-driver.tgz -git archive -o $TEST_TAR_FILE HEAD - -pushd $DRIVERS_TOOLS - -if [ $OIDC_ENV == "test" ]; then - echo "Test OIDC environment does not support remote test!" - exit 1 - -elif [ $OIDC_ENV == "azure" ]; then - export AZUREOIDC_DRIVERS_TAR_FILE=$TEST_TAR_FILE - export AZUREOIDC_TEST_CMD="OIDC_ENV=azure ./.evergreen/run-mongodb-oidc-test.sh" - bash ./.evergreen/auth_oidc/azure/run-driver-test.sh - -elif [ $OIDC_ENV == "gcp" ]; then - export GCPOIDC_DRIVERS_TAR_FILE=$TEST_TAR_FILE - export GCPOIDC_TEST_CMD="OIDC_ENV=gcp ./.evergreen/run-mongodb-oidc-test.sh" - bash ./.evergreen/auth_oidc/gcp/run-driver-test.sh - -elif [ $OIDC_ENV == "k8s" ]; then - # Make sure K8S_VARIANT is set. - if [ -z "$K8S_VARIANT" ]; then - echo "Must specify K8S_VARIANT" - popd - exit 1 - fi - - bash ./.evergreen/auth_oidc/k8s/setup-pod.sh - bash ./.evergreen/auth_oidc/k8s/run-self-test.sh - export K8S_DRIVERS_TAR_FILE=$TEST_TAR_FILE - export K8S_TEST_CMD="OIDC_ENV=k8s ./.evergreen/run-mongodb-oidc-test.sh" - source ./.evergreen/auth_oidc/k8s/secrets-export.sh # for MONGODB_URI - bash ./.evergreen/auth_oidc/k8s/run-driver-test.sh - bash ./.evergreen/auth_oidc/k8s/teardown-pod.sh - -else - echo "Unrecognized OIDC_ENV $OIDC_ENV" - pod - exit 1 -fi - -popd diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 46c4f24969..a60b112bcb 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -1,35 +1,15 @@ #!/bin/bash - -set +x # Disable debug trace +# Script run on a remote host to test MONGODB-OIDC. set -eu -echo "Running MONGODB-OIDC authentication tests" - -OIDC_ENV=${OIDC_ENV:-"test"} - -if [ $OIDC_ENV == "test" ]; then - # Make sure DRIVERS_TOOLS is set. - if [ -z "$DRIVERS_TOOLS" ]; then - echo "Must specify DRIVERS_TOOLS" - exit 1 - fi - source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh - -elif [ $OIDC_ENV == "azure" ]; then - source ./env.sh - -elif [ $OIDC_ENV == "gcp" ]; then - source ./secrets-export.sh - -elif [ $OIDC_ENV == "k8s" ]; then - echo "Running oidc on k8s" +echo "Running MONGODB-OIDC authentication tests on ${OIDC_ENV}..." 
+if [ ${OIDC_ENV} == "k8s" ]; then + SUB_TEST_NAME=$K8S_VARIANT-remote else - echo "Unrecognized OIDC_ENV $OIDC_ENV" - exit 1 + SUB_TEST_NAME=$OIDC_ENV-remote fi +bash ./.evergreen/just.sh setup-tests auth_oidc $SUB_TEST_NAME +bash ./.evergreen/just.sh run-tests "${@:1}" -export TEST_AUTH_OIDC=1 -export COVERAGE=1 -export AUTH="auth" -bash ./.evergreen/just.sh test-eg "${@:1}" +echo "Running MONGODB-OIDC authentication tests on ${OIDC_ENV}... done." diff --git a/.evergreen/run-perf-tests.sh b/.evergreen/run-perf-tests.sh deleted file mode 100755 index e6a51b3297..0000000000 --- a/.evergreen/run-perf-tests.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -set -o xtrace -set -o errexit - -git clone --depth 1 https://github.com/mongodb/specifications.git -pushd specifications/source/benchmarking/data -tar xf extended_bson.tgz -tar xf parallel.tgz -tar xf single_and_multi_document.tgz -popd - -export TEST_PATH="${PROJECT_DIRECTORY}/specifications/source/benchmarking/data" -export OUTPUT_FILE="${PROJECT_DIRECTORY}/results.json" - -export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 -export PERF_TEST=1 - -bash ./.evergreen/just.sh test-eg diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index fbe310ad1e..2b7d856d41 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -1,294 +1,38 @@ #!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -set -o xtrace +# Run a test suite that was configured with setup-tests.sh. +set -eu -# Note: It is assumed that you have already set up a virtual environment before running this file. +SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0}) +SCRIPT_DIR="$( cd -- "$SCRIPT_DIR" > /dev/null 2>&1 && pwd )" +ROOT_DIR="$(dirname $SCRIPT_DIR)" -# Supported/used environment variables: -# AUTH Set to enable authentication. Defaults to "noauth" -# SSL Set to enable SSL. Defaults to "nossl" -# GREEN_FRAMEWORK The green framework to test with, if any. -# COVERAGE If non-empty, run the test suite with coverage. -# COMPRESSORS If non-empty, install appropriate compressor. -# LIBMONGOCRYPT_URL The URL to download libmongocrypt. -# TEST_DATA_LAKE If non-empty, run data lake tests. -# TEST_ENCRYPTION If non-empty, run encryption tests. -# TEST_CRYPT_SHARED If non-empty, install crypt_shared lib. -# TEST_SERVERLESS If non-empy, test on serverless. -# TEST_LOADBALANCER If non-empy, test load balancing. -# TEST_FLE_AZURE_AUTO If non-empy, test auto FLE on Azure -# TEST_FLE_GCP_AUTO If non-empy, test auto FLE on GCP -# TEST_PYOPENSSL If non-empy, test with PyOpenSSL -# TEST_ENTERPRISE_AUTH If non-empty, test with Enterprise Auth -# TEST_AUTH_AWS If non-empty, test AWS Auth Mechanism -# TEST_AUTH_OIDC If non-empty, test OIDC Auth Mechanism -# TEST_PERF If non-empty, run performance tests -# TEST_OCSP If non-empty, run OCSP tests -# TEST_ATLAS If non-empty, test Atlas connections -# TEST_INDEX_MANAGEMENT If non-empty, run index management tests -# TEST_ENCRYPTION_PYOPENSSL If non-empy, test encryption with PyOpenSSL +pushd $ROOT_DIR -AUTH=${AUTH:-noauth} -SSL=${SSL:-nossl} -TEST_SUITES=${TEST_SUITES:-} -TEST_ARGS="${*:1}" - -export PIP_QUIET=1 # Quiet by default -export PIP_PREFER_BINARY=1 # Prefer binary dists by default - -set +x -PYTHON_IMPL=$(uv run --frozen python -c "import platform; print(platform.python_implementation())") - -# Try to source local Drivers Secrets -if [ -f ./secrets-export.sh ]; then - echo "Sourcing secrets" - source ./secrets-export.sh +# Try to source the env file. 
+if [ -f $SCRIPT_DIR/scripts/env.sh ]; then + echo "Sourcing env inputs" + . $SCRIPT_DIR/scripts/env.sh else - echo "Not sourcing secrets" -fi - -# Start compiling the args we'll pass to uv. -# Run in an isolated environment so as not to pollute the base venv. -UV_ARGS=("--isolated --frozen --extra test") - -# Ensure C extensions if applicable. -if [ -z "${NO_EXT:-}" ] && [ "$PYTHON_IMPL" = "CPython" ]; then - uv run --frozen tools/fail_if_no_c.py -fi - -if [ "$AUTH" != "noauth" ]; then - if [ -n "$TEST_DATA_LAKE" ]; then - export DB_USER="mhuser" - export DB_PASSWORD="pencil" - elif [ -n "$TEST_SERVERLESS" ]; then - source "${DRIVERS_TOOLS}"/.evergreen/serverless/secrets-export.sh - export DB_USER=$SERVERLESS_ATLAS_USER - export DB_PASSWORD=$SERVERLESS_ATLAS_PASSWORD - export MONGODB_URI="$SERVERLESS_URI" - echo "MONGODB_URI=$MONGODB_URI" - export SINGLE_MONGOS_LB_URI=$MONGODB_URI - export MULTI_MONGOS_LB_URI=$MONGODB_URI - elif [ -n "$TEST_AUTH_OIDC" ]; then - export DB_USER=$OIDC_ADMIN_USER - export DB_PASSWORD=$OIDC_ADMIN_PWD - export DB_IP="$MONGODB_URI" - else - export DB_USER="bob" - export DB_PASSWORD="pwd123" - fi - echo "Added auth, DB_USER: $DB_USER" -fi - -if [ -n "$TEST_ENTERPRISE_AUTH" ]; then - UV_ARGS+=("--extra gssapi") - if [ "Windows_NT" = "$OS" ]; then - echo "Setting GSSAPI_PASS" - export GSSAPI_PASS=${SASL_PASS} - export GSSAPI_CANONICALIZE="true" - else - # BUILD-3830 - touch krb5.conf.empty - export KRB5_CONFIG=${PROJECT_DIRECTORY}/.evergreen/krb5.conf.empty - - echo "Writing keytab" - echo ${KEYTAB_BASE64} | base64 -d > ${PROJECT_DIRECTORY}/.evergreen/drivers.keytab - echo "Running kinit" - kinit -k -t ${PROJECT_DIRECTORY}/.evergreen/drivers.keytab -p ${PRINCIPAL} - fi - echo "Setting GSSAPI variables" - export GSSAPI_HOST=${SASL_HOST} - export GSSAPI_PORT=${SASL_PORT} - export GSSAPI_PRINCIPAL=${PRINCIPAL} - - export TEST_SUITES="auth" -fi - -if [ -n "$TEST_LOADBALANCER" ]; then - export LOAD_BALANCER=1 - export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI:-mongodb://127.0.0.1:8000/?loadBalanced=true}" - export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI:-mongodb://127.0.0.1:8001/?loadBalanced=true}" - export TEST_SUITES="load_balancer" -fi - -if [ "$SSL" != "nossl" ]; then - export CLIENT_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/client.pem" - export CA_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/ca.pem" - - if [ -n "$TEST_LOADBALANCER" ]; then - export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}&tls=true" - export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}&tls=true" - fi -fi - -if [ "$COMPRESSORS" = "snappy" ]; then - UV_ARGS+=("--extra snappy") -elif [ "$COMPRESSORS" = "zstd" ]; then - UV_ARGS+=("--extra zstandard") -fi - -# PyOpenSSL test setup. -if [ -n "$TEST_PYOPENSSL" ]; then - UV_ARGS+=("--extra ocsp") -fi - -if [ -n "$TEST_ENCRYPTION" ] || [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then - # Check for libmongocrypt download. - if [ ! -d "libmongocrypt" ]; then - echo "Run encryption setup first!" - exit 1 - fi - - UV_ARGS+=("--extra encryption") - # TODO: Test with 'pip install pymongocrypt' - UV_ARGS+=("--group pymongocrypt_source") - - # Use the nocrypto build to avoid dependency issues with older windows/python versions. 
- BASE=$(pwd)/libmongocrypt/nocrypto - if [ -f "${BASE}/lib/libmongocrypt.so" ]; then - PYMONGOCRYPT_LIB=${BASE}/lib/libmongocrypt.so - elif [ -f "${BASE}/lib/libmongocrypt.dylib" ]; then - PYMONGOCRYPT_LIB=${BASE}/lib/libmongocrypt.dylib - elif [ -f "${BASE}/bin/mongocrypt.dll" ]; then - PYMONGOCRYPT_LIB=${BASE}/bin/mongocrypt.dll - # libmongocrypt's windows dll is not marked executable. - chmod +x $PYMONGOCRYPT_LIB - PYMONGOCRYPT_LIB=$(cygpath -m $PYMONGOCRYPT_LIB) - elif [ -f "${BASE}/lib64/libmongocrypt.so" ]; then - PYMONGOCRYPT_LIB=${BASE}/lib64/libmongocrypt.so - else - echo "Cannot find libmongocrypt shared object file" - exit 1 - fi - export PYMONGOCRYPT_LIB - # Ensure pymongocrypt is working properly. - # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} python -c "import pymongocrypt; print('pymongocrypt version: '+pymongocrypt.__version__)" - # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} python -c "import pymongocrypt; print('libmongocrypt version: '+pymongocrypt.libmongocrypt_version())" - # PATH is updated by configure-env.sh for access to mongocryptd. -fi - -if [ -n "$TEST_ENCRYPTION" ]; then - if [ -n "$TEST_ENCRYPTION_PYOPENSSL" ]; then - UV_ARGS+=("--extra ocsp") - fi - - if [ -n "$TEST_CRYPT_SHARED" ]; then - CRYPT_SHARED_DIR=`dirname $CRYPT_SHARED_LIB_PATH` - echo "using crypt_shared_dir $CRYPT_SHARED_DIR" - export DYLD_FALLBACK_LIBRARY_PATH=$CRYPT_SHARED_DIR:$DYLD_FALLBACK_LIBRARY_PATH - export LD_LIBRARY_PATH=$CRYPT_SHARED_DIR:$LD_LIBRARY_PATH - export PATH=$CRYPT_SHARED_DIR:$PATH - fi - # Only run the encryption tests. - TEST_SUITES="encryption" + echo "Not sourcing env inputs" fi -if [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then - if [[ -z "$SUCCESS" ]]; then - echo "Must define SUCCESS" - exit 1 - fi - - if echo "$MONGODB_URI" | grep -q "@"; then - echo "MONGODB_URI unexpectedly contains user credentials in FLE test!"; - exit 1 - fi - TEST_SUITES="csfle" -fi - -if [ -n "$TEST_INDEX_MANAGEMENT" ]; then - source $DRIVERS_TOOLS/.evergreen/atlas/secrets-export.sh - export DB_USER="${DRIVERS_ATLAS_LAMBDA_USER}" - set +x - export DB_PASSWORD="${DRIVERS_ATLAS_LAMBDA_PASSWORD}" - set -x - TEST_SUITES="index_management" -fi - -if [ -n "$TEST_DATA_LAKE" ] && [ -z "$TEST_ARGS" ]; then - TEST_SUITES="data_lake" -fi - -if [ -n "$TEST_ATLAS" ]; then - TEST_SUITES="atlas" -fi - -if [ -n "$TEST_OCSP" ]; then - UV_ARGS+=("--extra ocsp") - TEST_SUITES="ocsp" -fi - -if [ -n "$TEST_AUTH_AWS" ]; then - UV_ARGS+=("--extra aws") - TEST_SUITES="auth_aws" -fi - -if [ -n "$TEST_AUTH_OIDC" ]; then - UV_ARGS+=("--extra aws") - TEST_SUITES="auth_oidc" -fi - -if [ -n "$PERF_TEST" ]; then - UV_ARGS+=("--group perf") - start_time=$(date +%s) - TEST_SUITES="perf" - # PYTHON-4769 Run perf_test.py directly otherwise pytest's test collection negatively - # affects the benchmark results. - TEST_ARGS="test/performance/perf_test.py $TEST_ARGS" -fi - -echo "Running $AUTH tests over $SSL with python $(uv python find)" -uv run --frozen python -c 'import sys; print(sys.version)' - - -# Run the tests, and store the results in Evergreen compatible XUnit XML -# files in the xunit-results/ directory. - -# Run the tests with coverage if requested and coverage is installed. -# Only cover CPython. PyPy reports suspiciously low coverage. -if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then - # Keep in sync with combine-coverage.sh. - # coverage >=5 is needed for relative_files=true. 
- UV_ARGS+=("--group coverage") - TEST_ARGS="$TEST_ARGS --cov" -fi - -if [ -n "$GREEN_FRAMEWORK" ]; then - UV_ARGS+=("--group $GREEN_FRAMEWORK") -fi - -# Show the installed packages -# shellcheck disable=SC2048 -PIP_QUIET=0 uv run ${UV_ARGS[*]} --with pip pip list - -if [ -z "$GREEN_FRAMEWORK" ]; then - # Use --capture=tee-sys so pytest prints test output inline: - # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html - PYTEST_ARGS="-v --capture=tee-sys --durations=5 $TEST_ARGS" - if [ -n "$TEST_SUITES" ]; then - PYTEST_ARGS="-m $TEST_SUITES $PYTEST_ARGS" - fi - # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} pytest $PYTEST_ARGS +# Handle test inputs. +if [ -f $SCRIPT_DIR/scripts/test-env.sh ]; then + echo "Sourcing test inputs" + . $SCRIPT_DIR/scripts/test-env.sh else - # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS + echo "Missing test inputs, please run 'just setup-tests'" + exit 1 fi -# Handle perf test post actions. -if [ -n "$PERF_TEST" ]; then - end_time=$(date +%s) - elapsed_secs=$((end_time-start_time)) +# List the packages. +uv sync ${UV_ARGS} --reinstall +uv pip list - cat results.json +# Ensure we go back to base environment after the test. +trap "uv sync" EXIT HUP - echo "{\"failures\": 0, \"results\": [{\"status\": \"pass\", \"exit_code\": 0, \"test_file\": \"BenchMarkTests\", \"start\": $start_time, \"end\": $end_time, \"elapsed\": $elapsed_secs}]}" > report.json +# Start the test runner. +uv run ${UV_ARGS} .evergreen/scripts/run_tests.py "$@" - cat report.json -fi - -# Handle coverage post actions. -if [ -n "$COVERAGE" ]; then - rm -rf .pytest_cache -fi +popd diff --git a/.evergreen/scripts/__init__.py b/.evergreen/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.evergreen/scripts/archive-mongodb-logs.sh b/.evergreen/scripts/archive-mongodb-logs.sh deleted file mode 100755 index 70a337cd11..0000000000 --- a/.evergreen/scripts/archive-mongodb-logs.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set -o xtrace -mkdir out_dir -# shellcheck disable=SC2156 -find "$MONGO_ORCHESTRATION_HOME" -name \*.log -exec sh -c 'x="{}"; mv $x $PWD/out_dir/$(basename $(dirname $x))_$(basename $x)' \; -tar zcvf mongodb-logs.tar.gz -C out_dir/ . -rm -rf out_dir diff --git a/.evergreen/scripts/bootstrap-mongo-orchestration.sh b/.evergreen/scripts/bootstrap-mongo-orchestration.sh deleted file mode 100755 index 1d2b145de8..0000000000 --- a/.evergreen/scripts/bootstrap-mongo-orchestration.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -set -o xtrace - -# Enable core dumps if enabled on the machine -# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml -if [ -f /proc/self/coredump_filter ]; then - # Set the shell process (and its children processes) to dump ELF headers (bit 4), - # anonymous shared mappings (bit 1), and anonymous private mappings (bit 0). - echo 0x13 >/proc/self/coredump_filter - - if [ -f /sbin/sysctl ]; then - # Check that the core pattern is set explicitly on our distro image instead - # of being the OS's default value. This ensures that coredump names are consistent - # across distros and can be picked up by Evergreen. 
- core_pattern=$(/sbin/sysctl -n "kernel.core_pattern") - if [ "$core_pattern" = "dump_%e.%p.core" ]; then - echo "Enabling coredumps" - ulimit -c unlimited - fi - fi -fi - -if [ "$(uname -s)" = "Darwin" ]; then - core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile") - if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then - echo "Enabling coredumps" - ulimit -c unlimited - fi -fi - -if [ -n "${skip_crypt_shared}" ]; then - export SKIP_CRYPT_SHARED=1 -fi - -MONGODB_VERSION=${VERSION} \ - TOPOLOGY=${TOPOLOGY} \ - AUTH=${AUTH:-noauth} \ - SSL=${SSL:-nossl} \ - STORAGE_ENGINE=${STORAGE_ENGINE:-} \ - DISABLE_TEST_COMMANDS=${DISABLE_TEST_COMMANDS:-} \ - ORCHESTRATION_FILE=${ORCHESTRATION_FILE:-} \ - REQUIRE_API_VERSION=${REQUIRE_API_VERSION:-} \ - LOAD_BALANCER=${LOAD_BALANCER:-} \ - bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh -# run-orchestration generates expansion file with the MONGODB_URI for the cluster diff --git a/.evergreen/scripts/check-import-time.sh b/.evergreen/scripts/check-import-time.sh index cdd2025d59..f7a1117b97 100755 --- a/.evergreen/scripts/check-import-time.sh +++ b/.evergreen/scripts/check-import-time.sh @@ -1,7 +1,43 @@ #!/bin/bash +# Check for regressions in the import time of pymongo. +set -eu -. .evergreen/scripts/env.sh -set -x -export BASE_SHA="$1" -export HEAD_SHA="$2" -bash .evergreen/run-import-time-test.sh +HERE=$(dirname ${BASH_SOURCE:-$0}) + +source $HERE/env.sh + +pushd $HERE/../.. >/dev/null + +BASE_SHA="$1" +HEAD_SHA="$2" + +. .evergreen/utils.sh + +if [ -z "${PYTHON_BINARY:-}" ]; then + PYTHON_BINARY=$(find_python3) +fi + +# Use the previous commit if this was not a PR run. +if [ "$BASE_SHA" == "$HEAD_SHA" ]; then + BASE_SHA=$(git rev-parse HEAD~1) +fi + +function get_import_time() { + local log_file + createvirtualenv "$PYTHON_BINARY" import-venv + python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]" + # Import once to cache modules + python -c "import pymongo" + log_file="pymongo-$1.log" + python -X importtime -c "import pymongo" 2> $log_file +} + +get_import_time $HEAD_SHA +git stash || true +git checkout $BASE_SHA +get_import_time $BASE_SHA +git checkout $HEAD_SHA +git stash apply || true +python tools/compare_import_time.py $HEAD_SHA $BASE_SHA + +popd >/dev/null diff --git a/.evergreen/scripts/cleanup.sh b/.evergreen/scripts/cleanup.sh index a1fd92f04d..f04a936fd2 100755 --- a/.evergreen/scripts/cleanup.sh +++ b/.evergreen/scripts/cleanup.sh @@ -1,7 +1,14 @@ #!/bin/bash +# Clean up resources at the end of an evergreen run. +set -eu -if [ -f "$DRIVERS_TOOLS"/.evergreen/csfle/secrets-export.sh ]; then - bash .evergreen/teardown-encryption.sh +HERE=$(dirname ${BASH_SOURCE:-$0}) + +# Try to source the env file. +if [ -f $HERE/env.sh ]; then + echo "Sourcing env file" + source $HERE/env.sh fi + rm -rf "${DRIVERS_TOOLS}" || true -rm -f ./secrets-export.sh || true +rm -f $HERE/../../secrets-export.sh || true diff --git a/.evergreen/scripts/configure-env.sh b/.evergreen/scripts/configure-env.sh index cb018d09f0..81713f4191 100755 --- a/.evergreen/scripts/configure-env.sh +++ b/.evergreen/scripts/configure-env.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# Configure an evergreen test environment. 
 set -eu

 # Get the current unique version of this checkout
@@ -16,6 +16,18 @@ DRIVERS_TOOLS="$(dirname $PROJECT_DIRECTORY)/drivers-tools"
 CARGO_HOME=${CARGO_HOME:-${DRIVERS_TOOLS}/.cargo}
 UV_TOOL_DIR=$PROJECT_DIRECTORY/.local/uv/tools
 UV_CACHE_DIR=$PROJECT_DIRECTORY/.local/uv/cache
+DRIVERS_TOOLS_BINARIES="$DRIVERS_TOOLS/.bin"
+MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
+
+# On Evergreen jobs, "CI" will be set, and we don't want to write to $HOME.
+if [ "${CI:-}" == "true" ]; then
+  PYMONGO_BIN_DIR=${DRIVERS_TOOLS_BINARIES:-}
+# We want to use a path that's already on PATH on spawn hosts.
+else
+  PYMONGO_BIN_DIR=$HOME/cli_bin
+fi
+
+PATH_EXT="$MONGODB_BINARIES:$DRIVERS_TOOLS_BINARIES:$PYMONGO_BIN_DIR:\$PATH"

 # Python has cygwin path problems on Windows. Detect prospective mongo-orchestration home directory
 if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
@@ -24,6 +36,9 @@ if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin
   CARGO_HOME=$(cygpath -m $CARGO_HOME)
   UV_TOOL_DIR=$(cygpath -m "$UV_TOOL_DIR")
   UV_CACHE_DIR=$(cygpath -m "$UV_CACHE_DIR")
+  DRIVERS_TOOLS_BINARIES=$(cygpath -m "$DRIVERS_TOOLS_BINARIES")
+  MONGODB_BINARIES=$(cygpath -m "$MONGODB_BINARIES")
+  PYMONGO_BIN_DIR=$(cygpath -m "$PYMONGO_BIN_DIR")
 fi

 SCRIPT_DIR="$PROJECT_DIRECTORY/.evergreen/scripts"
@@ -36,50 +51,63 @@ fi

 export MONGO_ORCHESTRATION_HOME="$DRIVERS_TOOLS/.evergreen/orchestration"
 export MONGODB_BINARIES="$DRIVERS_TOOLS/mongodb/bin"
-export DRIVERS_TOOLS_BINARIES="$DRIVERS_TOOLS/.bin"

 cat <<EOT > "$SCRIPT_DIR"/env.sh
 export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
 export CURRENT_VERSION="$CURRENT_VERSION"
-export SKIP_LEGACY_SHELL=1
 export DRIVERS_TOOLS="$DRIVERS_TOOLS"
 export MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
 export MONGODB_BINARIES="$MONGODB_BINARIES"
 export DRIVERS_TOOLS_BINARIES="$DRIVERS_TOOLS_BINARIES"
 export PROJECT_DIRECTORY="$PROJECT_DIRECTORY"
-export SETDEFAULTENCODING="${SETDEFAULTENCODING:-}"
-export SKIP_CSOT_TESTS="${SKIP_CSOT_TESTS:-}"
-export MONGODB_STARTED="${MONGODB_STARTED:-}"
-export DISABLE_TEST_COMMANDS="${DISABLE_TEST_COMMANDS:-}"
-export GREEN_FRAMEWORK="${GREEN_FRAMEWORK:-}"
-export NO_EXT="${NO_EXT:-}"
-export COVERAGE="${COVERAGE:-}"
-export COMPRESSORS="${COMPRESSORS:-}"
-export MONGODB_API_VERSION="${MONGODB_API_VERSION:-}"
-export skip_crypt_shared="${skip_crypt_shared:-}"
-export STORAGE_ENGINE="${STORAGE_ENGINE:-}"
-export REQUIRE_API_VERSION="${REQUIRE_API_VERSION:-}"
-export skip_web_identity_auth_test="${skip_web_identity_auth_test:-}"
-export skip_ECS_auth_test="${skip_ECS_auth_test:-}"
 export CARGO_HOME="$CARGO_HOME"
-export TMPDIR="$MONGO_ORCHESTRATION_HOME/db"
 export UV_TOOL_DIR="$UV_TOOL_DIR"
 export UV_CACHE_DIR="$UV_CACHE_DIR"
 export UV_TOOL_BIN_DIR="$DRIVERS_TOOLS_BINARIES"
-export PATH="$MONGODB_BINARIES:$DRIVERS_TOOLS_BINARIES:$PATH"
+export PYMONGO_BIN_DIR="$PYMONGO_BIN_DIR"
+export PATH="$PATH_EXT"
 # shellcheck disable=SC2154
 export PROJECT="${project:-mongo-python-driver}"
 export PIP_QUIET=1
 EOT

-# Skip CSOT tests on non-linux platforms.
-if [ "$(uname -s)" != "Linux" ]; then
-  echo "export SKIP_CSOT_TESTS=1" >> $SCRIPT_DIR/env.sh
-fi
+# Write the .env file for drivers-tools.
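
Both env.sh above and the drivers-tools .env below follow the same pattern: snapshot the computed configuration into a file that later, separately invoked steps can source to reconstruct identical state. A rough Python sketch of the persist step (helper name and values are hypothetical):

    from pathlib import Path

    def write_env_file(path: Path, env: dict[str, str]) -> None:
        # Render a mapping as a bash-sourceable list of export statements.
        lines = [f'export {key}="{value}"' for key, value in env.items()]
        path.write_text("\n".join(lines) + "\n")

    write_env_file(Path("env.sh"), {"DRIVERS_TOOLS": "/tmp/drivers-tools", "PIP_QUIET": "1"})
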
+rm -rf $DRIVERS_TOOLS
+BRANCH=master
+ORG=mongodb-labs
+git clone --branch $BRANCH https://github.com/$ORG/drivers-evergreen-tools.git $DRIVERS_TOOLS
+
+cat <<EOT > ${DRIVERS_TOOLS}/.env
+SKIP_LEGACY_SHELL=1
+DRIVERS_TOOLS="$DRIVERS_TOOLS"
+MONGO_ORCHESTRATION_HOME="$MONGO_ORCHESTRATION_HOME"
+MONGODB_BINARIES="$MONGODB_BINARIES"
+EOT

 # Add these expansions to make it easier to call out tests scripts from the EVG yaml
 cat <<EOT > expansion.yml
 DRIVERS_TOOLS: "$DRIVERS_TOOLS"
 PROJECT_DIRECTORY: "$PROJECT_DIRECTORY"
 EOT
+
+# If the toolchain is available, symlink binaries to the bin dir. This has to be done
+# after drivers-tools is cloned, since we might be using its binary dir.
+_bin_path=""
+if [ "Windows_NT" == "${OS:-}" ]; then
+  _bin_path="/cygdrive/c/Python/Current/Scripts"
+elif [ "$(uname -s)" == "Darwin" ]; then
+  _bin_path="/Library/Frameworks/Python.Framework/Versions/Current/bin"
+else
+  _bin_path="/opt/python/Current/bin"
+fi
+if [ -d "${_bin_path}" ]; then
+  _suffix=""
+  if [ "Windows_NT" == "${OS:-}" ]; then
+    _suffix=".exe"
+  fi
+  mkdir -p $PYMONGO_BIN_DIR
+  ln -s ${_bin_path}/just${_suffix} $PYMONGO_BIN_DIR/just${_suffix}
+  ln -s ${_bin_path}/uv${_suffix} $PYMONGO_BIN_DIR/uv${_suffix}
+  ln -s ${_bin_path}/uvx${_suffix} $PYMONGO_BIN_DIR/uvx${_suffix}
+fi
diff --git a/.evergreen/scripts/download-and-merge-coverage.sh b/.evergreen/scripts/download-and-merge-coverage.sh
index 808bb957ef..c006813ba9 100755
--- a/.evergreen/scripts/download-and-merge-coverage.sh
+++ b/.evergreen/scripts/download-and-merge-coverage.sh
@@ -1,4 +1,4 @@
 #!/bin/bash
-
 # Download all the task coverage files.
+set -eu
 aws s3 cp --recursive s3://"$1"/coverage/"$2"/"$3"/coverage/ coverage/
diff --git a/.evergreen/scripts/fix-absolute-paths.sh b/.evergreen/scripts/fix-absolute-paths.sh
deleted file mode 100755
index eb9433c673..0000000000
--- a/.evergreen/scripts/fix-absolute-paths.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-#!/bin/bash
-
-set +x
-. src/.evergreen/scripts/env.sh
-# shellcheck disable=SC2044
-for filename in $(find $DRIVERS_TOOLS -name \*.json); do
-  perl -p -i -e "s|ABSOLUTE_PATH_REPLACEMENT_TOKEN|$DRIVERS_TOOLS|g" $filename
-done
diff --git a/.evergreen/scripts/generate-config.sh b/.evergreen/scripts/generate-config.sh
new file mode 100755
index 0000000000..70b4578cf9
--- /dev/null
+++ b/.evergreen/scripts/generate-config.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+# Entry point for the generate-config pre-commit hook.
+
+set -eu
+
+python .evergreen/scripts/generate_config.py
diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py
index e9624ab109..51ab3814cc 100644
--- a/.evergreen/scripts/generate_config.py
+++ b/.evergreen/scripts/generate_config.py
@@ -1,12 +1,4 @@
-# /// script
-# requires-python = ">=3.9"
-# dependencies = [
-#   "shrub.py>=3.2.0",
-#   "pyyaml>=6.0.2"
-# ]
-# ///
-
-# Note: Run this file with `pipx run`, or `uv run`.
+# Note: See CONTRIBUTING.md for how to update/run this file.
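
The rest of the patch reworks this generator, which builds the generated_configs YAML out of shrub.py's typed model rather than hand-written YAML. A condensed sketch of that flow, using the same shrub.v3 classes imported below (the variant contents are illustrative, and ShrubService.generate_yaml is assumed to be the serialization entry point, as in generate_yaml() later in this file):

    from shrub.v3.evg_build_variant import BuildVariant
    from shrub.v3.evg_project import EvgProject
    from shrub.v3.evg_task import EvgTaskRef
    from shrub.v3.shrub_service import ShrubService

    # One variant that pulls in every task tagged 'ocsp', as create_ocsp_variants() does.
    variant = BuildVariant(
        name="ocsp-rhel8",
        display_name="OCSP RHEL8",
        run_on=["rhel87-small"],
        tasks=[EvgTaskRef(name=".ocsp")],
    )
    project = EvgProject(tasks=None, buildvariants=[variant])
    print(ShrubService.generate_yaml(project))  # YAML consumed via config.yml's include list
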
from __future__ import annotations import sys @@ -17,16 +9,23 @@ from typing import Any from shrub.v3.evg_build_variant import BuildVariant -from shrub.v3.evg_command import FunctionCall +from shrub.v3.evg_command import ( + EvgCommandType, + FunctionCall, + archive_targz_pack, + ec2_assume_role, + s3_put, + subprocess_exec, +) from shrub.v3.evg_project import EvgProject -from shrub.v3.evg_task import EvgTask, EvgTaskRef +from shrub.v3.evg_task import EvgTask, EvgTaskDependency, EvgTaskRef from shrub.v3.shrub_service import ShrubService ############## # Globals ############## -ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] +ALL_VERSIONS = ["4.0", "4.2", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] PYPYS = ["pypy3.10"] ALL_PYTHONS = CPYTHONS + PYPYS @@ -34,7 +33,7 @@ BATCHTIME_WEEK = 10080 AUTH_SSLS = [("auth", "ssl"), ("noauth", "ssl"), ("noauth", "nossl")] TOPOLOGIES = ["standalone", "replica_set", "sharded_cluster"] -C_EXTS = ["with_ext", "without_ext"] +C_EXTS = ["without_ext", "with_ext"] # By default test each of the topologies with a subset of auth/ssl. SUB_TASKS = [ ".sharded_cluster .auth .ssl", @@ -69,12 +68,21 @@ class Host: HOSTS["ubuntu20"] = Host("ubuntu20", "ubuntu2004-small", "Ubuntu-20", dict()) HOSTS["ubuntu22"] = Host("ubuntu22", "ubuntu2204-small", "Ubuntu-22", dict()) HOSTS["rhel7"] = Host("rhel7", "rhel79-small", "RHEL7", dict()) +HOSTS["perf"] = Host("perf", "rhel90-dbx-perf-large", "", dict()) +HOSTS["debian11"] = Host("debian11", "debian11-small", "Debian11", dict()) DEFAULT_HOST = HOSTS["rhel8"] # Other hosts -OTHER_HOSTS = ["RHEL9-FIPS", "RHEL8-zseries", "RHEL8-POWER8", "RHEL8-arm64"] +OTHER_HOSTS = ["RHEL9-FIPS", "RHEL8-zseries", "RHEL8-POWER8", "RHEL8-arm64", "Amazon2023"] for name, run_on in zip( - OTHER_HOSTS, ["rhel92-fips", "rhel8-zseries-small", "rhel8-power-small", "rhel82-arm64-small"] + OTHER_HOSTS, + [ + "rhel92-fips", + "rhel8-zseries-small", + "rhel8-power-small", + "rhel82-arm64-small", + "amazon2023-arm64-latest-large-m8g", + ], ): HOSTS[name] = Host(name, run_on, name, dict()) @@ -85,7 +93,7 @@ class Host: def create_variant_generic( - task_names: list[str], + tasks: list[str | EvgTaskRef], display_name: str, *, host: Host | None = None, @@ -94,7 +102,12 @@ def create_variant_generic( **kwargs: Any, ) -> BuildVariant: """Create a build variant for the given inputs.""" - task_refs = [EvgTaskRef(name=n) for n in task_names] + task_refs = [] + for t in tasks: + if isinstance(t, EvgTaskRef): + task_refs.append(t) + else: + task_refs.append(EvgTaskRef(name=t)) expansions = expansions and expansions.copy() or dict() if "run_on" in kwargs: run_on = kwargs.pop("run_on") @@ -118,7 +131,7 @@ def create_variant_generic( def create_variant( - task_names: list[str], + tasks: list[str | EvgTaskRef], display_name: str, *, version: str | None = None, @@ -133,7 +146,7 @@ def create_variant( if python: expansions["PYTHON_BINARY"] = get_python_binary(python, host) return create_variant_generic( - task_names, display_name, version=version, host=host, expansions=expansions, **kwargs + tasks, display_name, version=version, host=host, expansions=expansions, **kwargs ) @@ -179,17 +192,14 @@ def get_versions_until(max_version: str) -> list[str]: return versions -def get_display_name(base: str, host: Host | None = None, **kwargs) -> str: - """Get the display name of a variant.""" +def get_common_name(base: str, sep: str, **kwargs) -> str: display_name = base - if host is not None: - 
display_name += f" {host.display_name}" version = kwargs.pop("VERSION", None) version = version or kwargs.pop("version", None) if version: if version not in ["rapid", "latest"]: version = f"v{version}" - display_name = f"{display_name} {version}" + display_name = f"{display_name}{sep}{version}" for key, value in kwargs.items(): name = value if key.lower() == "python": @@ -201,10 +211,22 @@ def get_display_name(base: str, host: Host | None = None, **kwargs) -> str: name = DISPLAY_LOOKUP[key.lower()][value] else: continue - display_name = f"{display_name} {name}" + display_name = f"{display_name}{sep}{name}" return display_name +def get_variant_name(base: str, host: Host | None = None, **kwargs) -> str: + """Get the display name of a variant.""" + display_name = base + if host is not None: + display_name += f" {host.display_name}" + return get_common_name(display_name, " ", **kwargs) + + +def get_task_name(base: str, **kwargs): + return get_common_name(base, "-", **kwargs).replace(" ", "-").lower() + + def zip_cycle(*iterables, empty_default=None): """Get all combinations of the inputs, cycling over the shorter list(s).""" cycles = [cycle(i) for i in iterables] @@ -212,12 +234,37 @@ def zip_cycle(*iterables, empty_default=None): yield tuple(next(i, empty_default) for i in cycles) -def handle_c_ext(c_ext, expansions): +def handle_c_ext(c_ext, expansions) -> None: """Handle c extension option.""" if c_ext == C_EXTS[0]: expansions["NO_EXT"] = "1" +def get_assume_role(**kwargs): + kwargs.setdefault("command_type", EvgCommandType.SETUP) + kwargs.setdefault("role_arn", "${assume_role_arn}") + return ec2_assume_role(**kwargs) + + +def get_subprocess_exec(**kwargs): + kwargs.setdefault("binary", "bash") + kwargs.setdefault("working_dir", "src") + kwargs.setdefault("command_type", EvgCommandType.TEST) + return subprocess_exec(**kwargs) + + +def get_s3_put(**kwargs): + kwargs["aws_key"] = "${AWS_ACCESS_KEY_ID}" + kwargs["aws_secret"] = "${AWS_SECRET_ACCESS_KEY}" # noqa:S105 + kwargs["aws_session_token"] = "${AWS_SESSION_TOKEN}" # noqa:S105 + kwargs["bucket"] = "${bucket_name}" + kwargs.setdefault("optional", "true") + kwargs.setdefault("permissions", "public-read") + kwargs.setdefault("content_type", "${content_type|application/x-gzip}") + kwargs.setdefault("command_type", EvgCommandType.SETUP) + return s3_put(**kwargs) + + def generate_yaml(tasks=None, variants=None): """Generate the yaml for a given set of tasks and variants.""" project = EvgProject(tasks=tasks, buildvariants=variants) @@ -234,41 +281,22 @@ def generate_yaml(tasks=None, variants=None): def create_ocsp_variants() -> list[BuildVariant]: variants = [] - batchtime = BATCHTIME_WEEK * 2 - expansions = dict(AUTH="noauth", SSL="ssl", TOPOLOGY="server") - base_display = "OCSP" - - # OCSP tests on default host with all servers v4.4+ and all python versions. - versions = [v for v in ALL_VERSIONS if v != "4.0"] - for version, python in zip_cycle(versions, ALL_PYTHONS): - host = DEFAULT_HOST - variant = create_variant( - [".ocsp"], - get_display_name(base_display, host, version=version, python=python), - python=python, - version=version, - host=host, - expansions=expansions, - batchtime=batchtime, - ) - variants.append(variant) - - # OCSP tests on Windows and MacOS. - # MongoDB servers on these hosts do not staple OCSP responses and only support RSA. - for host_name, version in product(["win64", "macos"], ["4.4", "8.0"]): + # OCSP tests on default host with all servers v4.4+. 
+ # MongoDB servers on Windows and MacOS do not staple OCSP responses and only support RSA. + # Only test with MongoDB 4.4 and latest. + for host_name in ["rhel8", "win64", "macos"]: host = HOSTS[host_name] - python = CPYTHONS[0] if version == "4.4" else CPYTHONS[-1] + if host == DEFAULT_HOST: + tasks = [".ocsp"] + else: + tasks = [".ocsp-rsa !.ocsp-staple .latest", ".ocsp-rsa !.ocsp-staple .4.4"] variant = create_variant( - [".ocsp-rsa !.ocsp-staple"], - get_display_name(base_display, host, version=version, python=python), - python=python, - version=version, + tasks, + get_variant_name("OCSP", host), host=host, - expansions=expansions, - batchtime=batchtime, + batchtime=BATCHTIME_WEEK, ) variants.append(variant) - return variants @@ -279,9 +307,10 @@ def create_server_variants() -> list[BuildVariant]: host = DEFAULT_HOST # Prefix the display name with an asterisk so it is sorted first. base_display_name = "* Test" - for python in [*MIN_MAX_PYTHON, PYPYS[-1]]: + for python, c_ext in product([*MIN_MAX_PYTHON, PYPYS[-1]], C_EXTS): expansions = dict(COVERAGE="coverage") - display_name = get_display_name(base_display_name, host, python=python, **expansions) + handle_c_ext(c_ext, expansions) + display_name = get_variant_name(base_display_name, host, python=python, **expansions) variant = create_variant( [f".{t} .sync_async" for t in TOPOLOGIES], display_name, @@ -295,7 +324,7 @@ def create_server_variants() -> list[BuildVariant]: # Test the rest of the pythons. for python in CPYTHONS[1:-1] + PYPYS[:-1]: display_name = f"Test {host}" - display_name = get_display_name(base_display_name, host, python=python) + display_name = get_variant_name(base_display_name, host, python=python) variant = create_variant( [f"{t} .sync_async" for t in SUB_TASKS], display_name, @@ -315,7 +344,7 @@ def create_server_variants() -> list[BuildVariant]: for version in get_versions_from("6.0"): tasks.extend(f"{t} .{version} !.sync_async" for t in SUB_TASKS) host = HOSTS[host_name] - display_name = get_display_name(base_display_name, host, python=python) + display_name = get_variant_name(base_display_name, host, python=python) variant = create_variant(tasks, display_name, python=python, host=host) variants.append(variant) @@ -331,7 +360,7 @@ def create_free_threaded_variants() -> list[BuildVariant]: tasks = [".free-threading"] host = HOSTS[host_name] python = "3.13t" - display_name = get_display_name("Free-threaded", host, python=python) + display_name = get_variant_name("Free-threaded", host, python=python) variant = create_variant(tasks, display_name, python=python, host=host) variants.append(variant) return variants @@ -343,20 +372,20 @@ def create_encryption_variants() -> list[BuildVariant]: batchtime = BATCHTIME_WEEK def get_encryption_expansions(encryption): - expansions = dict(test_encryption="true") + expansions = dict(TEST_NAME="encryption") if "crypt_shared" in encryption: - expansions["test_crypt_shared"] = "true" + expansions["TEST_CRYPT_SHARED"] = "true" if "PyOpenSSL" in encryption: - expansions["test_encryption_pyopenssl"] = "true" + expansions["SUB_TEST_NAME"] = "pyopenssl" return expansions host = DEFAULT_HOST # Test against all server versions for the three main python versions. 
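
The generators in this section pair unequal-length lists with zip_cycle, defined earlier in the file, so the shorter inputs repeat until the longest is exhausted. A self-contained restatement with arbitrary sample values:

    from itertools import cycle

    def zip_cycle(*iterables, empty_default=None):
        # Iterate max(len) times, drawing from each input cyclically.
        cycles = [cycle(i) for i in iterables]
        for _ in range(max(len(i) for i in iterables)):
            yield tuple(next(c, empty_default) for c in cycles)

    pairs = list(zip_cycle(["Encryption", "Encryption crypt_shared"], ["3.10", "3.11", "3.12"]))
    # -> [('Encryption', '3.10'), ('Encryption crypt_shared', '3.11'), ('Encryption', '3.12')]
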
-    encryptions = ["Encryption", "Encryption crypt_shared", "Encryption PyOpenSSL"]
+    encryptions = ["Encryption", "Encryption crypt_shared"]
     for encryption, python in product(encryptions, [*MIN_MAX_PYTHON, PYPYS[-1]]):
         expansions = get_encryption_expansions(encryption)
-        display_name = get_display_name(encryption, host, python=python, **expansions)
+        display_name = get_variant_name(encryption, host, python=python, **expansions)
         variant = create_variant(
             [f"{t} .sync_async" for t in SUB_TASKS],
             display_name,
@@ -368,10 +397,25 @@ def get_encryption_expansions(encryption):
         )
         variants.append(variant)

+    # Test PyOpenSSL against all server versions for all python versions.
+    for encryption, python in product(["Encryption PyOpenSSL"], [*MIN_MAX_PYTHON, PYPYS[-1]]):
+        expansions = get_encryption_expansions(encryption)
+        display_name = get_variant_name(encryption, host, python=python, **expansions)
+        variant = create_variant(
+            [f"{t} .sync" for t in SUB_TASKS],
+            display_name,
+            python=python,
+            host=host,
+            expansions=expansions,
+            batchtime=batchtime,
+            tags=tags,
+        )
+        variants.append(variant)
+
     # Test the rest of the pythons on linux for all server versions.
     for encryption, python, task in zip_cycle(encryptions, CPYTHONS[1:-1] + PYPYS[:-1], SUB_TASKS):
         expansions = get_encryption_expansions(encryption)
-        display_name = get_display_name(encryption, host, python=python, **expansions)
+        display_name = get_variant_name(encryption, host, python=python, **expansions)
         variant = create_variant(
             [f"{task} .sync_async"],
             display_name,
@@ -387,7 +431,7 @@ def get_encryption_expansions(encryption):
     for host_name, encryption, python in product(["macos", "win64"], encryptions, MIN_MAX_PYTHON):
         host = HOSTS[host_name]
         expansions = get_encryption_expansions(encryption)
-        display_name = get_display_name(encryption, host, python=python, **expansions)
+        display_name = get_variant_name(encryption, host, python=python, **expansions)
         variant = create_variant(
             task_names,
             display_name,
@@ -403,81 +447,42 @@ def get_encryption_expansions(encryption):

 def create_load_balancer_variants():
     # Load balancer tests - run all supported server versions using the lowest supported python.
-    host = DEFAULT_HOST
-    batchtime = BATCHTIME_WEEK
-    versions = get_versions_from("6.0")
-    variants = []
-    for version in versions:
-        python = CPYTHONS[0]
-        display_name = get_display_name("Load Balancer", host, python=python, version=version)
-        variant = create_variant(
-            [".load-balancer"],
-            display_name,
-            python=python,
-            host=host,
-            version=version,
-            batchtime=batchtime,
+    return [
+        create_variant(
+            [".load-balancer"], "Load Balancer", host=DEFAULT_HOST, batchtime=BATCHTIME_WEEK
         )
-        variants.append(variant)
-    return variants
+    ]


 def create_compression_variants():
-    # Compression tests - standalone versions of each server, across python versions, with and without c extensions.
-    # PyPy interpreters are always tested without extensions.
+    # Compression tests - standalone versions of each server, across python versions.
host = DEFAULT_HOST - base_task = ".standalone .noauth .nossl .sync_async" - task_names = dict(snappy=[base_task], zlib=[base_task], zstd=[f"{base_task} !.4.0"]) + base_task = ".compression" variants = [] - for ind, (compressor, c_ext) in enumerate(product(["snappy", "zlib", "zstd"], C_EXTS)): - expansions = dict(COMPRESSORS=compressor) - handle_c_ext(c_ext, expansions) - base_name = f"Compression {compressor}" - python = CPYTHONS[ind % len(CPYTHONS)] - display_name = get_display_name(base_name, host, python=python, **expansions) - variant = create_variant( - task_names[compressor], - display_name, - python=python, - host=host, - expansions=expansions, - ) - variants.append(variant) - - other_pythons = PYPYS + CPYTHONS[ind:] - for compressor, python in zip_cycle(["snappy", "zlib", "zstd"], other_pythons): - expansions = dict(COMPRESSORS=compressor) - handle_c_ext(c_ext, expansions) - base_name = f"Compression {compressor}" - display_name = get_display_name(base_name, host, python=python, **expansions) - variant = create_variant( - task_names[compressor], - display_name, - python=python, - host=host, - expansions=expansions, + for compressor in "snappy", "zlib", "zstd": + expansions = dict(COMPRESSOR=compressor) + tasks = [base_task] if compressor != "zstd" else [f"{base_task} !.4.0"] + display_name = get_variant_name(f"Compression {compressor}", host) + variants.append( + create_variant( + tasks, + display_name, + host=host, + expansions=expansions, + ) ) - variants.append(variant) - return variants def create_enterprise_auth_variants(): - expansions = dict(AUTH="auth") variants = [] - - # All python versions across platforms. - for python in ALL_PYTHONS: - if python == CPYTHONS[0]: - host = HOSTS["macos"] - elif python == CPYTHONS[-1]: - host = HOSTS["win64"] + for host in [HOSTS["macos"], HOSTS["win64"], DEFAULT_HOST]: + display_name = get_variant_name("Auth Enterprise", host) + if host == DEFAULT_HOST: + tags = [".enterprise_auth"] else: - host = DEFAULT_HOST - display_name = get_display_name("Auth Enterprise", host, python=python, **expansions) - variant = create_variant( - ["test-enterprise-auth"], display_name, host=host, python=python, expansions=expansions - ) + tags = [".enterprise_auth !.pypy"] + variant = create_variant(tags, display_name, host=host) variants.append(variant) return variants @@ -486,7 +491,7 @@ def create_enterprise_auth_variants(): def create_pyopenssl_variants(): base_name = "PyOpenSSL" batchtime = BATCHTIME_WEEK - expansions = dict(test_pyopenssl="true") + expansions = dict(TEST_NAME="default", SUB_TEST_NAME="pyopenssl") variants = [] for python in ALL_PYTHONS: @@ -500,15 +505,26 @@ def create_pyopenssl_variants(): else: host = DEFAULT_HOST - display_name = get_display_name(base_name, host, python=python) - variant = create_variant( - [f".replica_set .{auth} .{ssl} .sync_async", f".7.0 .{auth} .{ssl} .sync_async"], - display_name, - python=python, - host=host, - expansions=expansions, - batchtime=batchtime, - ) + display_name = get_variant_name(base_name, host, python=python) + # only need to run some on async + if python in (CPYTHONS[1], CPYTHONS[-1]): + variant = create_variant( + [f".replica_set .{auth} .{ssl} .sync_async", f".7.0 .{auth} .{ssl} .sync_async"], + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + else: + variant = create_variant( + [f".replica_set .{auth} .{ssl} .sync", f".7.0 .{auth} .{ssl} .sync"], + display_name, + python=python, + host=host, + expansions=expansions, + batchtime=batchtime, + 
) variants.append(variant) return variants @@ -529,7 +545,7 @@ def create_storage_engine_variants(): tasks = [f".standalone .{v} .noauth .nossl .sync_async" for v in versions] + [ f".replica_set .{v} .noauth .nossl .sync_async" for v in versions ] - display_name = get_display_name(f"Storage {engine}", host, python=python) + display_name = get_variant_name(f"Storage {engine}", host, python=python) variant = create_variant( tasks, display_name, host=host, python=python, expansions=expansions ) @@ -540,7 +556,6 @@ def create_storage_engine_variants(): def create_stable_api_variants(): host = DEFAULT_HOST tags = ["versionedApi_tag"] - tasks = [f".standalone .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0")] variants = [] types = ["require v1", "accept v2"] @@ -554,13 +569,19 @@ def create_stable_api_variants(): expansions["REQUIRE_API_VERSION"] = "1" # MONGODB_API_VERSION is the apiVersion to use in the test suite. expansions["MONGODB_API_VERSION"] = "1" + tasks = [ + f"!.replica_set .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0") + ] else: # Test against a cluster with acceptApiVersion2 but without # requireApiVersion, and don't automatically add apiVersion to # clients created in the test suite. expansions["ORCHESTRATION_FILE"] = "versioned-api-testing.json" + tasks = [ + f".standalone .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0") + ] base_display_name = f"Stable API {test_type}" - display_name = get_display_name(base_display_name, host, python=python, **expansions) + display_name = get_variant_name(base_display_name, host, python=python, **expansions) variant = create_variant( tasks, display_name, host=host, python=python, tags=tags, expansions=expansions ) @@ -575,7 +596,7 @@ def create_green_framework_variants(): host = DEFAULT_HOST for python, framework in product([CPYTHONS[0], CPYTHONS[-1]], ["eventlet", "gevent"]): expansions = dict(GREEN_FRAMEWORK=framework, AUTH="auth", SSL="ssl") - display_name = get_display_name(f"Green {framework.capitalize()}", host, python=python) + display_name = get_variant_name(f"Green {framework.capitalize()}", host, python=python) variant = create_variant( tasks, display_name, host=host, python=python, expansions=expansions ) @@ -587,10 +608,10 @@ def create_no_c_ext_variants(): variants = [] host = DEFAULT_HOST for python, topology in zip_cycle(CPYTHONS, TOPOLOGIES): - tasks = [f".{topology} .noauth .nossl .sync_async"] + tasks = [f".{topology} .noauth .nossl !.sync_async"] expansions = dict() handle_c_ext(C_EXTS[0], expansions) - display_name = get_display_name("No C Ext", host, python=python) + display_name = get_variant_name("No C Ext", host, python=python) variant = create_variant( tasks, display_name, host=host, python=python, expansions=expansions ) @@ -601,14 +622,10 @@ def create_no_c_ext_variants(): def create_atlas_data_lake_variants(): variants = [] host = HOSTS["ubuntu22"] - for python, c_ext in product(MIN_MAX_PYTHON, C_EXTS): - tasks = ["atlas-data-lake-tests"] - expansions = dict(AUTH="auth") - handle_c_ext(c_ext, expansions) - display_name = get_display_name("Atlas Data Lake", host, python=python, **expansions) - variant = create_variant( - tasks, display_name, host=host, python=python, expansions=expansions - ) + for python in MIN_MAX_PYTHON: + tasks = [".atlas_data_lake"] + display_name = get_variant_name("Atlas Data Lake", host, python=python) + variant = create_variant(tasks, display_name, host=host, python=python) variants.append(variant) return variants @@ -616,15 +633,10 @@ def 
create_atlas_data_lake_variants(): def create_mod_wsgi_variants(): variants = [] host = HOSTS["ubuntu22"] - tasks = [ - "mod-wsgi-standalone", - "mod-wsgi-replica-set", - "mod-wsgi-embedded-mode-standalone", - "mod-wsgi-embedded-mode-replica-set", - ] + tasks = [".mod_wsgi"] expansions = dict(MOD_WSGI_VERSION="4") for python in MIN_MAX_PYTHON: - display_name = get_display_name("mod_wsgi", host, python=python) + display_name = get_variant_name("mod_wsgi", host, python=python) variant = create_variant( tasks, display_name, host=host, python=python, expansions=expansions ) @@ -636,7 +648,7 @@ def create_disable_test_commands_variants(): host = DEFAULT_HOST expansions = dict(AUTH="auth", SSL="ssl", DISABLE_TEST_COMMANDS="1") python = CPYTHONS[0] - display_name = get_display_name("Disable test commands", host, python=python) + display_name = get_variant_name("Disable test commands", host, python=python) tasks = [".latest .sync_async"] return [create_variant(tasks, display_name, host=host, python=python, expansions=expansions)] @@ -644,16 +656,14 @@ def create_disable_test_commands_variants(): def create_serverless_variants(): host = DEFAULT_HOST batchtime = BATCHTIME_WEEK - expansions = dict(test_serverless="true", AUTH="auth", SSL="ssl") - tasks = ["serverless_task_group"] + tasks = [".serverless"] base_name = "Serverless" return [ create_variant( tasks, - get_display_name(base_name, host, python=python), + get_variant_name(base_name, host, python=python), host=host, python=python, - expansions=expansions, batchtime=batchtime, ) for python in MIN_MAX_PYTHON @@ -662,18 +672,18 @@ def create_serverless_variants(): def create_oidc_auth_variants(): variants = [] - other_tasks = ["testazureoidc_task_group", "testgcpoidc_task_group", "testk8soidc_task_group"] for host_name in ["ubuntu22", "macos", "win64"]: - tasks = ["testoidc_task_group"] if host_name == "ubuntu22": - tasks += other_tasks + tasks = [".auth_oidc"] + else: + tasks = [".auth_oidc !.auth_oidc_remote"] host = HOSTS[host_name] variants.append( create_variant( tasks, - get_display_name("Auth OIDC", host), + get_variant_name("Auth OIDC", host), host=host, - batchtime=BATCHTIME_WEEK * 2, + batchtime=BATCHTIME_WEEK, ) ) return variants @@ -684,8 +694,8 @@ def create_search_index_variants(): python = CPYTHONS[0] return [ create_variant( - ["test_atlas_task_group_search_indexes"], - get_display_name("Search Index Helpers", host, python=python), + [".search_index"], + get_variant_name("Search Index Helpers", host, python=python), python=python, host=host, ) @@ -697,8 +707,8 @@ def create_mockupdb_variants(): python = CPYTHONS[0] return [ create_variant( - ["mockupdb"], - get_display_name("MockupDB", host, python=python), + [".mockupdb"], + get_variant_name("MockupDB", host, python=python), python=python, host=host, ) @@ -710,8 +720,8 @@ def create_doctests_variants(): python = CPYTHONS[0] return [ create_variant( - ["doctests"], - get_display_name("Doctests", host, python=python), + [".doctests"], + get_variant_name("Doctests", host, python=python), python=python, host=host, ) @@ -722,8 +732,8 @@ def create_atlas_connect_variants(): host = DEFAULT_HOST return [ create_variant( - ["atlas-connect"], - get_display_name("Atlas connect", host, python=python), + [".atlas_connect"], + get_variant_name("Atlas connect", host, python=python), python=python, host=host, ) @@ -731,29 +741,48 @@ def create_atlas_connect_variants(): ] +def create_coverage_report_variants(): + return [create_variant(["coverage-report"], "Coverage Report", host=DEFAULT_HOST)] 
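+
+
+# For reference: a one-task creator like the one above is rendered to YAML by
+# ShrubService.generate_yaml() roughly as follows (an illustrative sketch;
+# exact fields come from create_variant() and the host definitions):
+#
+#   buildvariants:
+#   - name: coverage-report
+#     display_name: Coverage Report
+#     run_on: [<DEFAULT_HOST distro>]
+#     tasks:
+#     - name: coverage-report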
+ + +def create_kms_variants(): + tasks = [] + tasks.append(EvgTaskRef(name="test-gcpkms", batchtime=BATCHTIME_WEEK)) + tasks.append("test-gcpkms-fail") + tasks.append(EvgTaskRef(name="test-azurekms", batchtime=BATCHTIME_WEEK)) + tasks.append("test-azurekms-fail") + return [create_variant(tasks, "KMS", host=HOSTS["debian11"])] + + +def create_import_time_variants(): + return [create_variant(["check-import-time"], "Import Time", host=DEFAULT_HOST)] + + +def create_backport_pr_variants(): + return [create_variant(["backport-pr"], "Backport PR", host=DEFAULT_HOST)] + + +def create_perf_variants(): + host = HOSTS["perf"] + return [ + create_variant([".perf"], "Performance Benchmarks", host=host, batchtime=BATCHTIME_WEEK) + ] + + def create_aws_auth_variants(): variants = [] - tasks = [ - "aws-auth-test-4.4", - "aws-auth-test-5.0", - "aws-auth-test-6.0", - "aws-auth-test-7.0", - "aws-auth-test-8.0", - "aws-auth-test-rapid", - "aws-auth-test-latest", - ] for host_name, python in product(["ubuntu20", "win64", "macos"], MIN_MAX_PYTHON): expansions = dict() - if host_name != "ubuntu20": - expansions["skip_ECS_auth_test"] = "true" + tasks = [".auth-aws"] if host_name == "macos": - expansions["skip_EC2_auth_test"] = "true" - expansions["skip_web_identity_auth_test"] = "true" + tasks = [".auth-aws !.auth-aws-web-identity !.auth-aws-ecs !.auth-aws-ec2"] + elif host_name == "win64": + tasks = [".auth-aws !.auth-aws-ecs"] host = HOSTS[host_name] variant = create_variant( tasks, - get_display_name("Auth AWS", host, python=python), + get_variant_name("Auth AWS", host, python=python), host=host, python=python, expansions=expansions, @@ -762,6 +791,11 @@ def create_aws_auth_variants(): return variants +def create_no_server_variants(): + host = HOSTS["rhel8"] + return [create_variant([".no-server"], "No server", host=host)] + + def create_alternative_hosts_variants(): batchtime = BATCHTIME_WEEK variants = [] @@ -770,7 +804,7 @@ def create_alternative_hosts_variants(): variants.append( create_variant( [".5.0 .standalone !.sync_async"], - get_display_name("OpenSSL 1.0.2", host, python=CPYTHONS[0]), + get_variant_name("OpenSSL 1.0.2", host, python=CPYTHONS[0]), host=host, python=CPYTHONS[0], batchtime=batchtime, @@ -781,10 +815,13 @@ def create_alternative_hosts_variants(): handle_c_ext(C_EXTS[0], expansions) for host_name in OTHER_HOSTS: host = HOSTS[host_name] + tags = [".6.0 .standalone !.sync_async"] + if host_name == "Amazon2023": + tags = [f".latest !.sync_async {t}" for t in SUB_TASKS] variants.append( create_variant( - [".6.0 .standalone !.sync_async"], - display_name=get_display_name("Other hosts", host), + tags, + display_name=get_variant_name("Other hosts", host), batchtime=batchtime, host=host, expansions=expansions, @@ -793,6 +830,11 @@ def create_alternative_hosts_variants(): return variants +def create_aws_lambda_variants(): + host = HOSTS["rhel8"] + return [create_variant([".aws_lambda"], display_name="FaaS Lambda", host=host)] + + ############## # Tasks ############## @@ -803,45 +845,463 @@ def create_server_tasks(): for topo, version, (auth, ssl), sync in product(TOPOLOGIES, ALL_VERSIONS, AUTH_SSLS, SYNCS): name = f"test-{version}-{topo}-{auth}-{ssl}-{sync}".lower() tags = [version, topo, auth, ssl, sync] - bootstrap_vars = dict( + server_vars = dict( VERSION=version, TOPOLOGY=topo if topo != "standalone" else "server", AUTH=auth, SSL=ssl, ) - bootstrap_func = FunctionCall(func="bootstrap mongo-orchestration", vars=bootstrap_vars) - test_suites = "" + server_func = FunctionCall(func="run 
server", vars=server_vars) + test_vars = dict(AUTH=auth, SSL=ssl, SYNC=sync) if sync == "sync": - test_suites = "default" + test_vars["TEST_NAME"] = "default_sync" elif sync == "async": - test_suites = "default_async" - test_vars = dict( + test_vars["TEST_NAME"] = "default_async" + test_func = FunctionCall(func="run tests", vars=test_vars) + tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func])) + return tasks + + +def create_load_balancer_tasks(): + tasks = [] + for (auth, ssl), version in product(AUTH_SSLS, get_versions_from("6.0")): + name = get_task_name(f"test-load-balancer-{auth}-{ssl}", version=version) + tags = ["load-balancer", auth, ssl] + server_vars = dict( + TOPOLOGY="sharded_cluster", AUTH=auth, SSL=ssl, - SYNC=sync, - TEST_SUITES=test_suites, + TEST_NAME="load_balancer", + VERSION=version, ) + server_func = FunctionCall(func="run server", vars=server_vars) + test_vars = dict(AUTH=auth, SSL=ssl, TEST_NAME="load_balancer") test_func = FunctionCall(func="run tests", vars=test_vars) - tasks.append(EvgTask(name=name, tags=tags, commands=[bootstrap_func, test_func])) + tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func])) + return tasks -def create_load_balancer_tasks(): +def create_compression_tasks(): tasks = [] - for auth, ssl in AUTH_SSLS: - name = f"test-load-balancer-{auth}-{ssl}".lower() - tags = ["load-balancer", auth, ssl] - bootstrap_vars = dict(TOPOLOGY="sharded_cluster", AUTH=auth, SSL=ssl, LOAD_BALANCER="true") - bootstrap_func = FunctionCall(func="bootstrap mongo-orchestration", vars=bootstrap_vars) - balancer_func = FunctionCall(func="run load-balancer") - test_vars = dict(AUTH=auth, SSL=ssl, test_loadbalancer="true") + versions = get_versions_from("4.0") + # Test all server versions with min python. + for version in versions: + python = CPYTHONS[0] + tags = ["compression", version] + name = get_task_name("test-compression", python=python, version=version) + server_func = FunctionCall(func="run server", vars=dict(VERSION=version)) + test_func = FunctionCall(func="run tests") + tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func])) + + # Test latest with max python, with and without c exts. + version = "latest" + tags = ["compression", "latest"] + for c_ext in C_EXTS: + python = CPYTHONS[-1] + expansions = dict() + handle_c_ext(c_ext, expansions) + name = get_task_name("test-compression", python=python, version=version, **expansions) + server_func = FunctionCall(func="run server", vars=dict(VERSION=version)) + test_func = FunctionCall(func="run tests", vars=expansions) + tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func])) + + # Test on latest with pypy. 
+ python = PYPYS[-1] + name = get_task_name("test-compression", python=python, version=version) + server_func = FunctionCall(func="run server", vars=dict(VERSION=version)) + test_func = FunctionCall(func="run tests") + tasks.append(EvgTask(name=name, tags=tags, commands=[server_func, test_func])) + return tasks + + +def create_kms_tasks(): + tasks = [] + for kms_type in ["gcp", "azure"]: + for success in [True, False]: + name = f"test-{kms_type}kms" + sub_test_name = kms_type + if not success: + name += "-fail" + sub_test_name += "-fail" + commands = [] + if not success: + commands.append(FunctionCall(func="run server")) + test_vars = dict(TEST_NAME="kms", SUB_TEST_NAME=sub_test_name) + test_func = FunctionCall(func="run tests", vars=test_vars) + commands.append(test_func) + tasks.append(EvgTask(name=name, commands=commands)) + return tasks + + +def create_aws_tasks(): + tasks = [] + aws_test_types = [ + "regular", + "assume-role", + "ec2", + "env-creds", + "session-creds", + "web-identity", + "ecs", + ] + for version in get_versions_from("4.4"): + base_name = f"test-auth-aws-{version}" + base_tags = ["auth-aws"] + server_vars = dict(AUTH_AWS="1", VERSION=version) + server_func = FunctionCall(func="run server", vars=server_vars) + assume_func = FunctionCall(func="assume ec2 role") + for test_type in aws_test_types: + tags = [*base_tags, f"auth-aws-{test_type}"] + name = f"{base_name}-{test_type}" + test_vars = dict(TEST_NAME="auth_aws", SUB_TEST_NAME=test_type) + test_func = FunctionCall(func="run tests", vars=test_vars) + funcs = [server_func, assume_func, test_func] + tasks.append(EvgTask(name=name, tags=tags, commands=funcs)) + + tags = [*base_tags, "auth-aws-web-identity"] + name = f"{base_name}-web-identity-session-name" + test_vars = dict( + TEST_NAME="auth_aws", SUB_TEST_NAME="web-identity", AWS_ROLE_SESSION_NAME="test" + ) test_func = FunctionCall(func="run tests", vars=test_vars) + funcs = [server_func, assume_func, test_func] + tasks.append(EvgTask(name=name, tags=tags, commands=funcs)) + + return tasks + + +def create_oidc_tasks(): + tasks = [] + for sub_test in ["default", "azure", "gcp", "eks", "aks", "gke"]: + vars = dict(TEST_NAME="auth_oidc", SUB_TEST_NAME=sub_test) + test_func = FunctionCall(func="run tests", vars=vars) + task_name = f"test-auth-oidc-{sub_test}" + tags = ["auth_oidc"] + if sub_test != "default": + tags.append("auth_oidc_remote") + tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) + return tasks + + +def create_mod_wsgi_tasks(): + tasks = [] + for test, topology in product(["standalone", "embedded-mode"], ["standalone", "replica_set"]): + if test == "standalone": + task_name = "mod-wsgi-" + else: + task_name = "mod-wsgi-embedded-mode-" + task_name += topology.replace("_", "-") + server_vars = dict(TOPOLOGY=topology) + server_func = FunctionCall(func="run server", vars=server_vars) + vars = dict(TEST_NAME="mod_wsgi", SUB_TEST_NAME=test.split("-")[0]) + test_func = FunctionCall(func="run tests", vars=vars) + tags = ["mod_wsgi"] + commands = [server_func, test_func] + tasks.append(EvgTask(name=task_name, tags=tags, commands=commands)) + return tasks + + +def _create_ocsp_tasks(algo, variant, server_type, base_task_name): + tasks = [] + file_name = f"{algo}-basic-tls-ocsp-{variant}.json" + + for version in get_versions_from("4.4"): + if version == "latest": + python = MIN_MAX_PYTHON[-1] + else: + python = MIN_MAX_PYTHON[0] + + vars = dict( + ORCHESTRATION_FILE=file_name, + OCSP_SERVER_TYPE=server_type, + TEST_NAME="ocsp", + 
PYTHON_VERSION=python, + VERSION=version, + ) + test_func = FunctionCall(func="run tests", vars=vars) + + tags = ["ocsp", f"ocsp-{algo}", version] + if "disableStapling" not in variant: + tags.append("ocsp-staple") + + task_name = get_task_name( + f"test-ocsp-{algo}-{base_task_name}", python=python, version=version + ) + tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) + return tasks + + +def create_aws_lambda_tasks(): + assume_func = FunctionCall(func="assume ec2 role") + vars = dict(TEST_NAME="aws_lambda") + test_func = FunctionCall(func="run tests", vars=vars) + task_name = "test-aws-lambda-deployed" + tags = ["aws_lambda"] + commands = [assume_func, test_func] + return [EvgTask(name=task_name, tags=tags, commands=commands)] + + +def create_search_index_tasks(): + assume_func = FunctionCall(func="assume ec2 role") + server_func = FunctionCall(func="run server", vars=dict(TEST_NAME="search_index")) + vars = dict(TEST_NAME="search_index") + test_func = FunctionCall(func="run tests", vars=vars) + task_name = "test-search-index-helpers" + tags = ["search_index"] + commands = [assume_func, server_func, test_func] + return [EvgTask(name=task_name, tags=tags, commands=commands)] + + +def create_atlas_connect_tasks(): + vars = dict(TEST_NAME="atlas_connect") + assume_func = FunctionCall(func="assume ec2 role") + test_func = FunctionCall(func="run tests", vars=vars) + task_name = "test-atlas-connect" + tags = ["atlas_connect"] + return [EvgTask(name=task_name, tags=tags, commands=[assume_func, test_func])] + + +def create_enterprise_auth_tasks(): + tasks = [] + for python in [*MIN_MAX_PYTHON, PYPYS[-1]]: + vars = dict(TEST_NAME="enterprise_auth", AUTH="auth", PYTHON_VERSION=python) + server_func = FunctionCall(func="run server", vars=vars) + assume_func = FunctionCall(func="assume ec2 role") + test_func = FunctionCall(func="run tests", vars=vars) + task_name = get_task_name("test-enterprise-auth", python=python) + tags = ["enterprise_auth"] + if python in PYPYS: + tags += ["pypy"] tasks.append( - EvgTask(name=name, tags=tags, commands=[bootstrap_func, balancer_func, test_func]) + EvgTask(name=task_name, tags=tags, commands=[server_func, assume_func, test_func]) ) return tasks +def create_perf_tasks(): + tasks = [] + for version, ssl, sync in product(["8.0"], ["ssl", "nossl"], ["sync", "async"]): + vars = dict(VERSION=f"v{version}-perf", SSL=ssl) + server_func = FunctionCall(func="run server", vars=vars) + vars = dict(TEST_NAME="perf", SUB_TEST_NAME=sync) + test_func = FunctionCall(func="run tests", vars=vars) + attach_func = FunctionCall(func="attach benchmark test results") + send_func = FunctionCall(func="send dashboard data") + task_name = f"perf-{version}-standalone" + if ssl == "ssl": + task_name += "-ssl" + if sync == "async": + task_name += "-async" + tags = ["perf"] + commands = [server_func, test_func, attach_func, send_func] + tasks.append(EvgTask(name=task_name, tags=tags, commands=commands)) + return tasks + + +def create_atlas_data_lake_tasks(): + tags = ["atlas_data_lake"] + tasks = [] + for c_ext in C_EXTS: + vars = dict(TEST_NAME="data_lake") + handle_c_ext(c_ext, vars) + test_func = FunctionCall(func="run tests", vars=vars) + task_name = f"test-atlas-data-lake-{c_ext}" + tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) + return tasks + + +def create_getdata_tasks(): + # Wildcard task. Do you need to find out what tools are available and where? 
+ # Throw it here, and execute this task on all buildvariants + cmd = get_subprocess_exec(args=[".evergreen/scripts/run-getdata.sh"]) + return [EvgTask(name="getdata", commands=[cmd])] + + +def create_coverage_report_tasks(): + tags = ["coverage"] + task_name = "coverage-report" + # BUILD-3165: We can't use "*" (all tasks) and specify "variant". + # Instead list out all coverage tasks using tags. + # Run the coverage task even if some tasks fail. + # Run the coverage task even if some tasks are not scheduled in a patch build. + task_deps = [] + for name in [".standalone", ".replica_set", ".sharded_cluster"]: + task_deps.append( + EvgTaskDependency(name=name, variant=".coverage_tag", status="*", patch_optional=True) + ) + cmd = FunctionCall(func="download and merge coverage") + return [EvgTask(name=task_name, tags=tags, depends_on=task_deps, commands=[cmd])] + + +def create_import_time_tasks(): + name = "check-import-time" + tags = ["pr"] + args = [".evergreen/scripts/check-import-time.sh", "${revision}", "${github_commit}"] + cmd = get_subprocess_exec(args=args) + return [EvgTask(name=name, tags=tags, commands=[cmd])] + + +def create_backport_pr_tasks(): + name = "backport-pr" + args = [ + "${DRIVERS_TOOLS}/.evergreen/github_app/backport-pr.sh", + "mongodb", + "mongo-python-driver", + "${github_commit}", + ] + cmd = get_subprocess_exec(args=args) + return [EvgTask(name=name, commands=[cmd], allowed_requesters=["commit"])] + + +def create_ocsp_tasks(): + tasks = [] + tests = [ + ("disableStapling", "valid", "valid-cert-server-does-not-staple"), + ("disableStapling", "revoked", "invalid-cert-server-does-not-staple"), + ("disableStapling", "valid-delegate", "delegate-valid-cert-server-does-not-staple"), + ("disableStapling", "revoked-delegate", "delegate-invalid-cert-server-does-not-staple"), + ("disableStapling", "no-responder", "soft-fail"), + ("mustStaple", "valid", "valid-cert-server-staples"), + ("mustStaple", "revoked", "invalid-cert-server-staples"), + ("mustStaple", "valid-delegate", "delegate-valid-cert-server-staples"), + ("mustStaple", "revoked-delegate", "delegate-invalid-cert-server-staples"), + ( + "mustStaple-disableStapling", + "revoked", + "malicious-invalid-cert-mustStaple-server-does-not-staple", + ), + ( + "mustStaple-disableStapling", + "revoked-delegate", + "delegate-malicious-invalid-cert-mustStaple-server-does-not-staple", + ), + ( + "mustStaple-disableStapling", + "no-responder", + "malicious-no-responder-mustStaple-server-does-not-staple", + ), + ] + for algo in ["ecdsa", "rsa"]: + for variant, server_type, base_task_name in tests: + new_tasks = _create_ocsp_tasks(algo, variant, server_type, base_task_name) + tasks.extend(new_tasks) + + return tasks + + +def create_mockupdb_tasks(): + test_func = FunctionCall(func="run tests", vars=dict(TEST_NAME="mockupdb")) + task_name = "test-mockupdb" + tags = ["mockupdb"] + return [EvgTask(name=task_name, tags=tags, commands=[test_func])] + + +def create_doctest_tasks(): + server_func = FunctionCall(func="run server") + test_func = FunctionCall(func="run just script", vars=dict(JUSTFILE_TARGET="docs-test")) + task_name = "test-doctests" + tags = ["doctests"] + return [EvgTask(name=task_name, tags=tags, commands=[server_func, test_func])] + + +def create_no_server_tasks(): + test_func = FunctionCall(func="run tests") + task_name = "test-no-server" + tags = ["no-server"] + return [EvgTask(name=task_name, tags=tags, commands=[test_func])] + + +def create_free_threading_tasks(): + vars = dict(VERSION="8.0", TOPOLOGY="replica_set") + 
server_func = FunctionCall(func="run server", vars=vars) + test_func = FunctionCall(func="run tests") + task_name = "test-free-threading" + tags = ["free-threading"] + return [EvgTask(name=task_name, tags=tags, commands=[server_func, test_func])] + + +def create_serverless_tasks(): + vars = dict(TEST_NAME="serverless", AUTH="auth", SSL="ssl") + test_func = FunctionCall(func="run tests", vars=vars) + tags = ["serverless"] + task_name = "test-serverless" + return [EvgTask(name=task_name, tags=tags, commands=[test_func])] + + +############## +# Functions +############## + + +def create_upload_coverage_func(): + # Upload the coverage report for all tasks in a single build to the same directory. + remote_file = ( + "coverage/${revision}/${version_id}/coverage/coverage.${build_variant}.${task_name}" + ) + display_name = "Raw Coverage Report" + cmd = get_s3_put( + local_file="src/.coverage", + remote_file=remote_file, + display_name=display_name, + content_type="text/html", + ) + return "upload coverage", [get_assume_role(), cmd] + + +def create_download_and_merge_coverage_func(): + include_expansions = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + args = [ + ".evergreen/scripts/download-and-merge-coverage.sh", + "${bucket_name}", + "${revision}", + "${version_id}", + ] + merge_cmd = get_subprocess_exec( + silent=True, include_expansions_in_env=include_expansions, args=args + ) + combine_cmd = get_subprocess_exec(args=[".evergreen/combine-coverage.sh"]) + # Upload the resulting html coverage report. + args = [ + ".evergreen/scripts/upload-coverage-report.sh", + "${bucket_name}", + "${revision}", + "${version_id}", + ] + upload_cmd = get_subprocess_exec( + silent=True, include_expansions_in_env=include_expansions, args=args + ) + display_name = "Coverage Report HTML" + remote_file = "coverage/${revision}/${version_id}/htmlcov/index.html" + put_cmd = get_s3_put( + local_file="src/htmlcov/index.html", + remote_file=remote_file, + display_name=display_name, + content_type="text/html", + ) + cmds = [get_assume_role(), merge_cmd, combine_cmd, upload_cmd, put_cmd] + return "download and merge coverage", cmds + + +def create_upload_mo_artifacts_func(): + include = ["./**.core", "./**.mdmp"] # Windows: minidumps + archive_cmd = archive_targz_pack(target="mongo-coredumps.tgz", source_dir="./", include=include) + display_name = "Core Dumps - Execution" + remote_file = "${build_variant}/${revision}/${version_id}/${build_id}/coredumps/${task_id}-${execution}-mongodb-coredumps.tar.gz" + s3_dumps = get_s3_put( + local_file="mongo-coredumps.tgz", remote_file=remote_file, display_name=display_name + ) + display_name = "drivers-tools-logs.tar.gz" + remote_file = "${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-drivers-tools-logs.tar.gz" + s3_logs = get_s3_put( + local_file="${DRIVERS_TOOLS}/.evergreen/test_logs.tar.gz", + remote_file=remote_file, + display_name=display_name, + ) + cmds = [get_assume_role(), archive_cmd, s3_dumps, s3_logs] + return "upload mo artifacts", cmds + + ################## # Generate Config ################## @@ -856,7 +1316,7 @@ def write_variants_to_file(): with target.open("w") as fid: fid.write("buildvariants:\n") - for name, func in getmembers(mod, isfunction): + for name, func in sorted(getmembers(mod, isfunction)): if not name.endswith("_variants"): continue if not name.startswith("create_"): @@ -886,8 +1346,8 @@ def write_tasks_to_file(): with target.open("w") as fid: fid.write("tasks:\n") - for name, func in 
getmembers(mod, isfunction): - if not name.endswith("_tasks"): + for name, func in sorted(getmembers(mod, isfunction)): + if name.startswith("_") or not name.endswith("_tasks"): continue if not name.startswith("create_"): raise ValueError("Task creators must start with create_") @@ -907,5 +1367,40 @@ def write_tasks_to_file(): fid.write(f"{line}\n") +def write_functions_to_file(): + mod = sys.modules[__name__] + here = Path(__file__).absolute().parent + target = here.parent / "generated_configs" / "functions.yml" + if target.exists(): + target.unlink() + with target.open("w") as fid: + fid.write("functions:\n") + + functions = dict() + for name, func in sorted(getmembers(mod, isfunction)): + if name.startswith("_") or not name.endswith("_func"): + continue + if not name.startswith("create_"): + raise ValueError("Function creators must start with create_") + title = name.replace("create_", "").replace("_func", "").replace("_", " ").capitalize() + func_name, cmds = func() + functions = dict() + functions[func_name] = cmds + project = EvgProject(functions=functions, tasks=None, buildvariants=None) + out = ShrubService.generate_yaml(project).splitlines() + with target.open("a") as fid: + fid.write(f" # {title}\n") + for line in out[1:]: + fid.write(f"{line}\n") + fid.write("\n") + + # Remove extra trailing newline: + data = target.read_text().splitlines() + with target.open("w") as fid: + for line in data[:-1]: + fid.write(f"{line}\n") + + write_variants_to_file() write_tasks_to_file() +write_functions_to_file() diff --git a/.evergreen/scripts/init-test-results.sh b/.evergreen/scripts/init-test-results.sh deleted file mode 100755 index 666ac60620..0000000000 --- a/.evergreen/scripts/init-test-results.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -set +x -. src/.evergreen/scripts/env.sh -echo '{"results": [{ "status": "FAIL", "test_file": "Build", "log_raw": "No test-results.json found was created" } ]}' >$PROJECT_DIRECTORY/test-results.json diff --git a/.evergreen/scripts/install-dependencies.sh b/.evergreen/scripts/install-dependencies.sh index 39b77199bb..780d250a2b 100755 --- a/.evergreen/scripts/install-dependencies.sh +++ b/.evergreen/scripts/install-dependencies.sh @@ -1,30 +1,45 @@ #!/bin/bash - +# Install the dependencies needed for an evergreen run. set -eu -# On Evergreen jobs, "CI" will be set, and we don't want to write to $HOME. -if [ "${CI:-}" == "true" ]; then - _BIN_DIR=${DRIVERS_TOOLS_BINARIES:-} -else - _BIN_DIR=$HOME/.local/bin +HERE=$(dirname ${BASH_SOURCE:-$0}) +pushd "$(dirname "$(dirname $HERE)")" > /dev/null + +# Source the env files to pick up common variables. +if [ -f $HERE/env.sh ]; then + . $HERE/env.sh fi +_BIN_DIR=${PYMONGO_BIN_DIR:-$HOME/.local/bin} +export PATH="$PATH:${_BIN_DIR}" # Helper function to pip install a dependency using a temporary python env. function _pip_install() { _HERE=$(dirname ${BASH_SOURCE:-$0}) . $_HERE/../utils.sh _VENV_PATH=$(mktemp -d) + if [ "Windows_NT" = "${OS:-}" ]; then + _VENV_PATH=$(cygpath -m $_VENV_PATH) + fi echo "Installing $2 using pip..." createvirtualenv "$(find_python3)" $_VENV_PATH python -m pip install $1 - ln -s "$(which $2)" $_BIN_DIR/$2 + _suffix="" + if [ "Windows_NT" = "${OS:-}" ]; then + _suffix=".exe" + fi + ln -s "$(which $2)" $_BIN_DIR/${2}${_suffix} + # uv also comes with a uvx binary. + if [ $2 == "uv" ]; then + ln -s "$(which uvx)" $_BIN_DIR/uvx${_suffix} + fi + echo "Installed to ${_BIN_DIR}" echo "Installing $2 using pip... done." } # Ensure just is installed. -if ! 
command -v just 2>/dev/null; then +if ! command -v just >/dev/null 2>&1; then # On most systems we can install directly. _TARGET="" if [ "Windows_NT" = "${OS:-}" ]; then @@ -35,21 +50,20 @@ if ! command -v just 2>/dev/null; then curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- $_TARGET --to "$_BIN_DIR" || { _pip_install rust-just just } - if ! command -v just 2>/dev/null; then - export PATH="$PATH:$_BIN_DIR" - fi echo "Installing just... done." fi # Install uv. -if ! command -v uv 2>/dev/null; then +if ! command -v uv >/dev/null 2>&1; then echo "Installing uv..." # On most systems we can install directly. curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="$_BIN_DIR" INSTALLER_NO_MODIFY_PATH=1 sh || { _pip_install uv uv } - if ! command -v uv 2>/dev/null; then - export PATH="$PATH:$_BIN_DIR" + if [ "Windows_NT" = "${OS:-}" ]; then + chmod +x "$(cygpath -u $_BIN_DIR)/uv.exe" fi echo "Installing uv... done." fi + +popd > /dev/null diff --git a/.evergreen/scripts/kms_tester.py b/.evergreen/scripts/kms_tester.py new file mode 100644 index 0000000000..40fd65919d --- /dev/null +++ b/.evergreen/scripts/kms_tester.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import os + +from utils import ( + DRIVERS_TOOLS, + LOGGER, + TMP_DRIVER_FILE, + create_archive, + read_env, + run_command, + write_env, +) + +DIRS = dict( + gcp=f"{DRIVERS_TOOLS}/.evergreen/csfle/gcpkms", + azure=f"{DRIVERS_TOOLS}/.evergreen/csfle/azurekms", +) + + +def _setup_azure_vm(base_env: dict[str, str]) -> None: + LOGGER.info("Setting up Azure VM...") + azure_dir = DIRS["azure"] + env = base_env.copy() + env["AZUREKMS_SRC"] = TMP_DRIVER_FILE + env["AZUREKMS_DST"] = "~/" + run_command(f"{azure_dir}/copy-file.sh", env=env) + + env = base_env.copy() + env["AZUREKMS_CMD"] = "tar xf mongo-python-driver.tgz" + run_command(f"{azure_dir}/run-command.sh", env=env) + + env["AZUREKMS_CMD"] = "bash .evergreen/just.sh setup-tests kms azure-remote" + run_command(f"{azure_dir}/run-command.sh", env=env) + LOGGER.info("Setting up Azure VM... 
done.") + + +def _setup_gcp_vm(base_env: dict[str, str]) -> None: + LOGGER.info("Setting up GCP VM...") + gcp_dir = DIRS["gcp"] + env = base_env.copy() + env["GCPKMS_SRC"] = TMP_DRIVER_FILE + env["GCPKMS_DST"] = f"{env['GCPKMS_INSTANCENAME']}:" + run_command(f"{gcp_dir}/copy-file.sh", env=env) + + env = base_env.copy() + env["GCPKMS_CMD"] = "tar xf mongo-python-driver.tgz" + run_command(f"{gcp_dir}/run-command.sh", env=env) + + env["GCPKMS_CMD"] = "bash ./.evergreen/just.sh setup-tests kms gcp-remote" + run_command(f"{gcp_dir}/run-command.sh", env=env) + LOGGER.info("Setting up GCP VM...") + + +def _load_kms_config(sub_test_target: str) -> dict[str, str]: + target_dir = DIRS[sub_test_target] + config = read_env(f"{target_dir}/secrets-export.sh") + base_env = os.environ.copy() + for key, value in config.items(): + base_env[key] = str(value) + return base_env + + +def setup_kms(sub_test_name: str) -> None: + if "-" in sub_test_name: + sub_test_target, sub_test_type = sub_test_name.split("-") + else: + sub_test_target = sub_test_name + sub_test_type = "" + + assert sub_test_target in ["azure", "gcp"], sub_test_target + assert sub_test_type in ["", "remote", "fail"], sub_test_type + success = sub_test_type != "fail" + kms_dir = DIRS[sub_test_target] + + if sub_test_target == "azure": + write_env("TEST_FLE_AZURE_AUTO") + else: + write_env("TEST_FLE_GCP_AUTO") + + write_env("SUCCESS", success) + + # For remote tests, there is no further work required. + if sub_test_type == "remote": + return + + if sub_test_target == "azure": + run_command("./setup-secrets.sh", cwd=kms_dir) + + if success: + create_archive() + if sub_test_target == "azure": + os.environ["AZUREKMS_VMNAME_PREFIX"] = "PYTHON_DRIVER" + + run_command("./setup.sh", cwd=kms_dir) + base_env = _load_kms_config(sub_test_target) + + if sub_test_target == "azure": + _setup_azure_vm(base_env) + else: + _setup_gcp_vm(base_env) + + if sub_test_target == "azure": + config = read_env(f"{kms_dir}/secrets-export.sh") + if success: + write_env("AZUREKMS_VMNAME", config["AZUREKMS_VMNAME"]) + + write_env("KEY_NAME", config["AZUREKMS_KEYNAME"]) + write_env("KEY_VAULT_ENDPOINT", config["AZUREKMS_KEYVAULTENDPOINT"]) + + +def test_kms_send_to_remote(sub_test_name: str) -> None: + env = _load_kms_config(sub_test_name) + if sub_test_name == "azure": + key_name = os.environ["KEY_NAME"] + key_vault_endpoint = os.environ["KEY_VAULT_ENDPOINT"] + env[ + "AZUREKMS_CMD" + ] = f'KEY_NAME="{key_name}" KEY_VAULT_ENDPOINT="{key_vault_endpoint}" bash ./.evergreen/just.sh run-tests' + else: + env["GCPKMS_CMD"] = "./.evergreen/just.sh run-tests" + cmd = f"{DIRS[sub_test_name]}/run-command.sh" + run_command(cmd, env=env) + + +def teardown_kms(sub_test_name: str) -> None: + run_command(f"{DIRS[sub_test_name]}/teardown.sh") + + +if __name__ == "__main__": + setup_kms() diff --git a/.evergreen/scripts/make-files-executable.sh b/.evergreen/scripts/make-files-executable.sh deleted file mode 100755 index 806be7c599..0000000000 --- a/.evergreen/scripts/make-files-executable.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set +x -. 
src/.evergreen/scripts/env.sh -# shellcheck disable=SC2044 -for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do - chmod +x "$i" -done diff --git a/.evergreen/scripts/mod_wsgi_tester.py b/.evergreen/scripts/mod_wsgi_tester.py new file mode 100644 index 0000000000..5968849068 --- /dev/null +++ b/.evergreen/scripts/mod_wsgi_tester.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import os +import sys +import time +import urllib.error +import urllib.request +from pathlib import Path +from shutil import which + +from utils import LOGGER, ROOT, run_command, write_env + + +def make_request(url, timeout=10): + for _ in range(int(timeout)): + try: + urllib.request.urlopen(url) # noqa: S310 + return + except urllib.error.HTTPError: + pass + time.sleep(1) + raise TimeoutError(f"Failed to access {url}") + + +def setup_mod_wsgi(sub_test_name: str) -> None: + env = os.environ.copy() + if sub_test_name == "embedded": + env["MOD_WSGI_CONF"] = "mod_wsgi_test_embedded.conf" + elif sub_test_name == "standalone": + env["MOD_WSGI_CONF"] = "mod_wsgi_test.conf" + else: + raise ValueError("mod_wsgi sub test must be either 'standalone' or 'embedded'") + write_env("MOD_WSGI_CONF", env["MOD_WSGI_CONF"]) + apache = which("apache2") + if not apache and Path("/usr/lib/apache2/mpm-prefork/apache2").exists(): + apache = "/usr/lib/apache2/mpm-prefork/apache2" + if apache: + apache_config = "apache24ubuntu161404.conf" + else: + apache = which("httpd") + if not apache: + raise ValueError("Could not find apache2 or httpd") + apache_config = "apache22amazon.conf" + python_version = ".".join(str(val) for val in sys.version_info[:2]) + mod_wsgi_version = 4 + so_file = f"/opt/python/mod_wsgi/python_version/{python_version}/mod_wsgi_version/{mod_wsgi_version}/mod_wsgi.so" + write_env("MOD_WSGI_SO", so_file) + env["MOD_WSGI_SO"] = so_file + env["PYTHONHOME"] = f"/opt/python/{python_version}" + env["PROJECT_DIRECTORY"] = project_directory = str(ROOT) + write_env("APACHE_BINARY", apache) + write_env("APACHE_CONFIG", apache_config) + uri1 = f"http://localhost:8080/interpreter1{project_directory}" + write_env("TEST_URI1", uri1) + uri2 = f"http://localhost:8080/interpreter2{project_directory}" + write_env("TEST_URI2", uri2) + run_command(f"{apache} -k start -f {ROOT}/test/mod_wsgi_test/{apache_config}", env=env) + + # Wait for the endpoints to be available. 
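+    # (make_request() retries once per second until the timeout elapses, so a
+    # hung Apache start surfaces as a TimeoutError and the error_log is dumped.)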
+    try:
+        make_request(uri1, 10)
+        make_request(uri2, 10)
+    except Exception as e:
+        LOGGER.error(Path("error_log").read_text())
+        raise e
+
+
+def test_mod_wsgi() -> None:
+    sys.path.insert(0, str(ROOT))
+    from test.mod_wsgi_test.test_client import main, parse_args
+
+    uri1 = os.environ["TEST_URI1"]
+    uri2 = os.environ["TEST_URI2"]
+    args = f"-n 25000 -t 100 parallel {uri1} {uri2}"
+    try:
+        main(*parse_args(args.split()))
+
+        args = f"-n 25000 serial {uri1} {uri2}"
+        main(*parse_args(args.split()))
+    except Exception as e:
+        LOGGER.error(Path("error_log").read_text())
+        raise e
+
+
+def teardown_mod_wsgi() -> None:
+    apache = os.environ["APACHE_BINARY"]
+    apache_config = os.environ["APACHE_CONFIG"]
+
+    run_command(f"{apache} -k stop -f {ROOT}/test/mod_wsgi_test/{apache_config}")
+
+
+if __name__ == "__main__":
+    # setup_mod_wsgi() requires the sub test name ("standalone" or "embedded").
+    setup_mod_wsgi(sys.argv[1])
diff --git a/.evergreen/scripts/oidc_tester.py b/.evergreen/scripts/oidc_tester.py
new file mode 100644
index 0000000000..fd702cf1d1
--- /dev/null
+++ b/.evergreen/scripts/oidc_tester.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import os
+
+from utils import DRIVERS_TOOLS, TMP_DRIVER_FILE, create_archive, read_env, run_command, write_env
+
+K8S_NAMES = ["aks", "gke", "eks"]
+K8S_REMOTE_NAMES = [f"{n}-remote" for n in K8S_NAMES]
+
+
+def _get_target_dir(sub_test_name: str) -> str:
+    if sub_test_name == "default":
+        target_dir = "auth_oidc"
+    elif sub_test_name.startswith("azure"):
+        target_dir = "auth_oidc/azure"
+    elif sub_test_name.startswith("gcp"):
+        target_dir = "auth_oidc/gcp"
+    elif sub_test_name in K8S_NAMES + K8S_REMOTE_NAMES:
+        target_dir = "auth_oidc/k8s"
+    else:
+        raise ValueError(f"Invalid sub test name '{sub_test_name}'")
+    return f"{DRIVERS_TOOLS}/.evergreen/{target_dir}"
+
+
+def setup_oidc(sub_test_name: str) -> dict[str, str] | None:
+    target_dir = _get_target_dir(sub_test_name)
+    env = os.environ.copy()
+
+    if sub_test_name == "eks" and "AWS_ACCESS_KEY_ID" in os.environ:
+        # Store AWS creds for kubectl access.
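+        # (write_env() persists them for later steps in the task, e.g. so the
+        # pod teardown can still reach the cluster.)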
+        for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
+            if key in os.environ:
+                write_env(key, os.environ[key])
+
+    if sub_test_name == "azure":
+        env["AZUREOIDC_VMNAME_PREFIX"] = "PYTHON_DRIVER"
+    if "-remote" not in sub_test_name:
+        run_command(f"bash {target_dir}/setup.sh", env=env)
+    if sub_test_name in K8S_NAMES:
+        run_command(f"bash {target_dir}/setup-pod.sh {sub_test_name}")
+        run_command(f"bash {target_dir}/run-self-test.sh")
+        return None
+
+    source_file = None
+    if sub_test_name == "default":
+        source_file = f"{target_dir}/secrets-export.sh"
+    elif sub_test_name in ["azure-remote", "gcp-remote"]:
+        source_file = "./secrets-export.sh"
+    if sub_test_name in K8S_REMOTE_NAMES:
+        return os.environ.copy()
+    if source_file is None:
+        return None
+
+    config = read_env(source_file)
+    write_env("MONGODB_URI_SINGLE", config["MONGODB_URI_SINGLE"])
+    write_env("MONGODB_URI", config["MONGODB_URI"])
+    write_env("DB_IP", config["MONGODB_URI"])
+
+    if sub_test_name == "default":
+        write_env("OIDC_TOKEN_FILE", config["OIDC_TOKEN_FILE"])
+        write_env("OIDC_TOKEN_DIR", config["OIDC_TOKEN_DIR"])
+        if "OIDC_DOMAIN" in config:
+            write_env("OIDC_DOMAIN", config["OIDC_DOMAIN"])
+    elif sub_test_name == "azure-remote":
+        write_env("AZUREOIDC_RESOURCE", config["AZUREOIDC_RESOURCE"])
+    elif sub_test_name == "gcp-remote":
+        write_env("GCPOIDC_AUDIENCE", config["GCPOIDC_AUDIENCE"])
+    return config
+
+
+def test_oidc_send_to_remote(sub_test_name: str) -> None:
+    env = os.environ.copy()
+    target_dir = _get_target_dir(sub_test_name)
+    create_archive()
+    if sub_test_name in ["azure", "gcp"]:
+        upper_name = sub_test_name.upper()
+        env[f"{upper_name}OIDC_DRIVERS_TAR_FILE"] = TMP_DRIVER_FILE
+        env[
+            f"{upper_name}OIDC_TEST_CMD"
+        ] = f"OIDC_ENV={sub_test_name} ./.evergreen/run-mongodb-oidc-test.sh"
+    elif sub_test_name in K8S_NAMES:
+        env["K8S_DRIVERS_TAR_FILE"] = TMP_DRIVER_FILE
+        env["K8S_TEST_CMD"] = "OIDC_ENV=k8s ./.evergreen/run-mongodb-oidc-test.sh"
+    run_command(f"bash {target_dir}/run-driver-test.sh", env=env)
+
+
+def teardown_oidc(sub_test_name: str) -> None:
+    target_dir = _get_target_dir(sub_test_name)
+    # For k8s, make sure an error while tearing down the pod doesn't prevent
+    # the Atlas server teardown.
+    error = None
+    if sub_test_name in K8S_NAMES:
+        try:
+            run_command(f"bash {target_dir}/teardown-pod.sh")
+        except Exception as e:
+            error = e
+    run_command(f"bash {target_dir}/teardown.sh")
+    if error:
+        raise error
diff --git a/.evergreen/scripts/prepare-resources.sh b/.evergreen/scripts/prepare-resources.sh
deleted file mode 100755
index da869e7055..0000000000
--- a/.evergreen/scripts/prepare-resources.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/bin/bash
-set -eu
-
-HERE=$(dirname ${BASH_SOURCE:-$0})
-pushd $HERE
-. env.sh
-
-rm -rf $DRIVERS_TOOLS
-git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS
-echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" >$MONGO_ORCHESTRATION_HOME/orchestration.config
-
-popd
-
-# Copy PyMongo's test certificates over driver-evergreen-tools'
-cp ${PROJECT_DIRECTORY}/test/certificates/* ${DRIVERS_TOOLS}/.evergreen/x509gen/
-
-# Replace MongoOrchestration's client certificate.
-cp ${PROJECT_DIRECTORY}/test/certificates/client.pem ${MONGO_ORCHESTRATION_HOME}/lib/client.pem - -if [ -w /etc/hosts ]; then - SUDO="" -else - SUDO="sudo" -fi - -# Add 'server' and 'hostname_not_in_cert' as a hostnames -echo "127.0.0.1 server" | $SUDO tee -a /etc/hosts -echo "127.0.0.1 hostname_not_in_cert" | $SUDO tee -a /etc/hosts diff --git a/.evergreen/scripts/run-atlas-tests.sh b/.evergreen/scripts/run-atlas-tests.sh deleted file mode 100755 index 30b8d5a615..0000000000 --- a/.evergreen/scripts/run-atlas-tests.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -# Disable xtrace for security reasons (just in case it was accidentally set). -set +x -set -o errexit -bash "${DRIVERS_TOOLS}"/.evergreen/auth_aws/setup_secrets.sh drivers/atlas_connect -TEST_ATLAS=1 bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh test-eg diff --git a/.evergreen/scripts/run-aws-ecs-auth-test.sh b/.evergreen/scripts/run-aws-ecs-auth-test.sh deleted file mode 100755 index 787e0a710b..0000000000 --- a/.evergreen/scripts/run-aws-ecs-auth-test.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -# shellcheck disable=SC2154 -if [ "${skip_ECS_auth_test}" = "true" ]; then - echo "This platform does not support the ECS auth test, skipping..." - exit 0 -fi -set -ex -cd "$DRIVERS_TOOLS"/.evergreen/auth_aws -. ./activate-authawsvenv.sh -. aws_setup.sh ecs -export MONGODB_BINARIES="$MONGODB_BINARIES" -export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" -python aws_tester.py ecs -cd - diff --git a/.evergreen/scripts/run-direct-tests.sh b/.evergreen/scripts/run-direct-tests.sh deleted file mode 100755 index a00235311c..0000000000 --- a/.evergreen/scripts/run-direct-tests.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -set -x -. .evergreen/utils.sh - -. .evergreen/scripts/env.sh -createvirtualenv "$PYTHON_BINARY" .venv - -export PYMONGO_C_EXT_MUST_BUILD=1 -pip install -e ".[test]" -pytest -v diff --git a/.evergreen/scripts/run-doctests.sh b/.evergreen/scripts/run-doctests.sh deleted file mode 100755 index 5950e2c107..0000000000 --- a/.evergreen/scripts/run-doctests.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -set -o xtrace -PYTHON_BINARY=${PYTHON_BINARY} bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh docs-test diff --git a/.evergreen/scripts/run-enterprise-auth-tests.sh b/.evergreen/scripts/run-enterprise-auth-tests.sh deleted file mode 100755 index e015a34ca4..0000000000 --- a/.evergreen/scripts/run-enterprise-auth-tests.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -set -eu - -# Disable xtrace for security reasons (just in case it was accidentally set). -set +x -# Use the default python to bootstrap secrets. -bash "${DRIVERS_TOOLS}"/.evergreen/secrets_handling/setup-secrets.sh drivers/enterprise_auth -TEST_ENTERPRISE_AUTH=1 AUTH=auth bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh test-eg diff --git a/.evergreen/scripts/run-gcpkms-fail-test.sh b/.evergreen/scripts/run-gcpkms-fail-test.sh deleted file mode 100755 index 594a2984fa..0000000000 --- a/.evergreen/scripts/run-gcpkms-fail-test.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -. 
.evergreen/scripts/env.sh -export PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 -export LIBMONGOCRYPT_URL=https://s3.amazonaws.com/mciuploads/libmongocrypt/debian11/master/latest/libmongocrypt.tar.gz -SKIP_SERVERS=1 bash ./.evergreen/setup-encryption.sh -SUCCESS=false TEST_FLE_GCP_AUTO=1 ./.evergreen/just.sh test-eg diff --git a/.evergreen/scripts/run-getdata.sh b/.evergreen/scripts/run-getdata.sh index b2d6ecb476..9435a5fcc3 100755 --- a/.evergreen/scripts/run-getdata.sh +++ b/.evergreen/scripts/run-getdata.sh @@ -1,11 +1,14 @@ #!/bin/bash +# Get the debug data for an evergreen task. +set -eu -set -o xtrace -. ${DRIVERS_TOOLS}/.evergreen/download-mongodb.sh || true +. ${DRIVERS_TOOLS}/.evergreen/get-distro.sh || true get_distro || true echo $DISTRO echo $MARCH echo $OS + +set -x uname -a || true ls /etc/*release* || true cc --version || true @@ -20,3 +23,4 @@ ls -la /usr/local/Cellar/ || true scan-build --version || true genhtml --version || true valgrind --version || true +set +x diff --git a/.evergreen/scripts/run-load-balancer.sh b/.evergreen/scripts/run-load-balancer.sh deleted file mode 100755 index 7d431777e5..0000000000 --- a/.evergreen/scripts/run-load-balancer.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -MONGODB_URI=${MONGODB_URI} bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh start diff --git a/.evergreen/scripts/run-mockupdb-tests.sh b/.evergreen/scripts/run-mockupdb-tests.sh deleted file mode 100755 index 32594f05d3..0000000000 --- a/.evergreen/scripts/run-mockupdb-tests.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -set -o xtrace -export PYTHON_BINARY=${PYTHON_BINARY} -bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh test-mockupdb diff --git a/.evergreen/scripts/run-mod-wsgi-tests.sh b/.evergreen/scripts/run-mod-wsgi-tests.sh deleted file mode 100755 index 607458b8c6..0000000000 --- a/.evergreen/scripts/run-mod-wsgi-tests.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -set -o xtrace -set -o errexit - -APACHE=$(command -v apache2 || command -v /usr/lib/apache2/mpm-prefork/apache2) || true -if [ -n "$APACHE" ]; then - APACHE_CONFIG=apache24ubuntu161404.conf -else - APACHE=$(command -v httpd) || true - if [ -z "$APACHE" ]; then - echo "Could not find apache2 binary" - exit 1 - else - APACHE_CONFIG=apache22amazon.conf - fi -fi - - -PYTHON_VERSION=$(${PYTHON_BINARY} -c "import sys; sys.stdout.write('.'.join(str(val) for val in sys.version_info[:2]))") - -# Ensure the C extensions are installed. -${PYTHON_BINARY} -m venv --system-site-packages .venv -source .venv/bin/activate -pip install -U pip -python -m pip install -e . - -export MOD_WSGI_SO=/opt/python/mod_wsgi/python_version/$PYTHON_VERSION/mod_wsgi_version/$MOD_WSGI_VERSION/mod_wsgi.so -export PYTHONHOME=/opt/python/$PYTHON_VERSION -# If MOD_WSGI_EMBEDDED is set use the default embedded mode behavior instead -# of daemon mode (WSGIDaemonProcess). -if [ -n "${MOD_WSGI_EMBEDDED:-}" ]; then - export MOD_WSGI_CONF=mod_wsgi_test_embedded.conf -else - export MOD_WSGI_CONF=mod_wsgi_test.conf -fi - -cd .. 
-$APACHE -k start -f ${PROJECT_DIRECTORY}/test/mod_wsgi_test/${APACHE_CONFIG} -trap '$APACHE -k stop -f ${PROJECT_DIRECTORY}/test/mod_wsgi_test/${APACHE_CONFIG}' EXIT HUP - -wget -t 1 -T 10 -O - "http://localhost:8080/interpreter1${PROJECT_DIRECTORY}" || (cat error_log && exit 1) -wget -t 1 -T 10 -O - "http://localhost:8080/interpreter2${PROJECT_DIRECTORY}" || (cat error_log && exit 1) - -python ${PROJECT_DIRECTORY}/test/mod_wsgi_test/test_client.py -n 25000 -t 100 parallel \ - http://localhost:8080/interpreter1${PROJECT_DIRECTORY} http://localhost:8080/interpreter2${PROJECT_DIRECTORY} || \ - (tail -n 100 error_log && exit 1) - -python ${PROJECT_DIRECTORY}/test/mod_wsgi_test/test_client.py -n 25000 serial \ - http://localhost:8080/interpreter1${PROJECT_DIRECTORY} http://localhost:8080/interpreter2${PROJECT_DIRECTORY} || \ - (tail -n 100 error_log && exit 1) - -rm -rf .venv diff --git a/.evergreen/scripts/run-mongodb-aws-test.sh b/.evergreen/scripts/run-mongodb-aws-test.sh deleted file mode 100755 index 88c3236b3f..0000000000 --- a/.evergreen/scripts/run-mongodb-aws-test.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -set -o xtrace -set -o errexit # Exit the script with error if any of the commands fail - -############################################ -# Main Program # -############################################ - -# Supported/used environment variables: -# MONGODB_URI Set the URI, including an optional username/password to use -# to connect to the server via MONGODB-AWS authentication -# mechanism. -# PYTHON_BINARY The Python version to use. - -# shellcheck disable=SC2154 -if [ "${skip_EC2_auth_test:-}" = "true" ] && { [ "$1" = "ec2" ] || [ "$1" = "web-identity" ]; }; then - echo "This platform does not support the EC2 auth test, skipping..." - exit 0 -fi - -echo "Running MONGODB-AWS authentication tests for $1" - -# Handle credentials and environment setup. -. "$DRIVERS_TOOLS"/.evergreen/auth_aws/aws_setup.sh "$1" - -# show test output -set -x - -export TEST_AUTH_AWS=1 -export AUTH="auth" -export SET_XTRACE_ON=1 -bash ./.evergreen/just.sh test-eg diff --git a/.evergreen/scripts/run-ocsp-test.sh b/.evergreen/scripts/run-ocsp-test.sh deleted file mode 100755 index 328bd2f203..0000000000 --- a/.evergreen/scripts/run-ocsp-test.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -TEST_OCSP=1 \ -PYTHON_BINARY="${PYTHON_BINARY}" \ -CA_FILE="${DRIVERS_TOOLS}/.evergreen/ocsp/${OCSP_ALGORITHM}/ca.pem" \ -OCSP_TLS_SHOULD_SUCCEED="${OCSP_TLS_SHOULD_SUCCEED}" \ -bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh test-eg -bash "${DRIVERS_TOOLS}"/.evergreen/ocsp/teardown.sh diff --git a/.evergreen/scripts/run-perf-tests.sh b/.evergreen/scripts/run-perf-tests.sh deleted file mode 100755 index 69a369fee1..0000000000 --- a/.evergreen/scripts/run-perf-tests.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -PROJECT_DIRECTORY=${PROJECT_DIRECTORY} -bash "${PROJECT_DIRECTORY}"/.evergreen/run-perf-tests.sh diff --git a/.evergreen/scripts/run-server.sh b/.evergreen/scripts/run-server.sh new file mode 100755 index 0000000000..298eedcd3e --- /dev/null +++ b/.evergreen/scripts/run-server.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -eu + +HERE=$(dirname ${BASH_SOURCE:-$0}) + +# Try to source the env file. 
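+# (env.sh is generated by the setup scripts, so sourcing it is optional and
+# this script can also run from a local checkout; any flags are parsed by
+# run_server.py via get_test_options() and forwarded to run-orchestration.sh.)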
+if [ -f $HERE/env.sh ]; then + echo "Sourcing env file" + source $HERE/env.sh +fi + +uv run $HERE/run_server.py "$@" diff --git a/.evergreen/scripts/run-tests.sh b/.evergreen/scripts/run-tests.sh deleted file mode 100755 index ea923b3f5e..0000000000 --- a/.evergreen/scripts/run-tests.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -# Disable xtrace -set +x -if [ -n "${MONGODB_STARTED}" ]; then - export PYMONGO_MUST_CONNECT=true -fi -if [ -n "${DISABLE_TEST_COMMANDS}" ]; then - export PYMONGO_DISABLE_TEST_COMMANDS=1 -fi -if [ -n "${test_encryption}" ]; then - # Disable xtrace (just in case it was accidentally set). - set +x - bash "${DRIVERS_TOOLS}"/.evergreen/csfle/await-servers.sh - export TEST_ENCRYPTION=1 - if [ -n "${test_encryption_pyopenssl}" ]; then - export TEST_ENCRYPTION_PYOPENSSL=1 - fi -fi -if [ -n "${test_crypt_shared}" ]; then - export TEST_CRYPT_SHARED=1 - export CRYPT_SHARED_LIB_PATH=${CRYPT_SHARED_LIB_PATH} -fi -if [ -n "${test_pyopenssl}" ]; then - export TEST_PYOPENSSL=1 -fi -if [ -n "${SETDEFAULTENCODING}" ]; then - export SETDEFAULTENCODING="${SETDEFAULTENCODING}" -fi -if [ -n "${test_loadbalancer}" ]; then - export TEST_LOADBALANCER=1 - export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}" - export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}" -fi -if [ -n "${test_serverless}" ]; then - export TEST_SERVERLESS=1 -fi -if [ -n "${TEST_INDEX_MANAGEMENT:-}" ]; then - export TEST_INDEX_MANAGEMENT=1 -fi -if [ -n "${SKIP_CSOT_TESTS}" ]; then - export SKIP_CSOT_TESTS=1 -fi -GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \ - PYTHON_BINARY=${PYTHON_BINARY} \ - NO_EXT=${NO_EXT} \ - COVERAGE=${COVERAGE} \ - COMPRESSORS=${COMPRESSORS} \ - AUTH=${AUTH} \ - SSL=${SSL} \ - TEST_DATA_LAKE=${TEST_DATA_LAKE:-} \ - TEST_SUITES=${TEST_SUITES:-} \ - MONGODB_API_VERSION=${MONGODB_API_VERSION} \ - bash "${PROJECT_DIRECTORY}"/.evergreen/just.sh test-eg diff --git a/.evergreen/scripts/run-with-env.sh b/.evergreen/scripts/run-with-env.sh deleted file mode 100755 index 2fd073605d..0000000000 --- a/.evergreen/scripts/run-with-env.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash -eu - -# Example use: bash run-with-env.sh run-tests.sh {args...} - -# Parameter expansion to get just the current directory's name -if [ "${PWD##*/}" == "src" ]; then - . .evergreen/scripts/env.sh - if [ -f ".evergreen/scripts/test-env.sh" ]; then - . .evergreen/scripts/test-env.sh - fi -else - . src/.evergreen/scripts/env.sh - if [ -f "src/.evergreen/scripts/test-env.sh" ]; then - . src/.evergreen/scripts/test-env.sh - fi -fi - -set -eu - -# shellcheck source=/dev/null -. "$@" diff --git a/.evergreen/scripts/run_server.py b/.evergreen/scripts/run_server.py new file mode 100644 index 0000000000..a35fbb57a8 --- /dev/null +++ b/.evergreen/scripts/run_server.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import os +from typing import Any + +from utils import DRIVERS_TOOLS, ROOT, get_test_options, run_command + + +def set_env(name: str, value: Any = "1") -> None: + os.environ[name] = str(value) + + +def start_server(): + opts, extra_opts = get_test_options( + "Run a MongoDB server. All given flags will be passed to run-orchestration.sh in DRIVERS_TOOLS.", + require_sub_test_name=False, + allow_extra_opts=True, + ) + test_name = opts.test_name + + # drivers-evergreen-tools expects the version variable to be named MONGODB_VERSION. 
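+    # (The generated configs set VERSION, e.g. VERSION=8.0 or VERSION=latest;
+    # it is mirrored into MONGODB_VERSION before run-orchestration.sh runs.)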
+ if "VERSION" in os.environ: + os.environ["MONGODB_VERSION"] = os.environ["VERSION"] + + if test_name == "auth_aws": + set_env("AUTH_AWS") + + elif test_name == "load_balancer": + set_env("LOAD_BALANCER") + + elif test_name == "search_index": + os.environ["TOPOLOGY"] = "replica_set" + os.environ["MONGODB_VERSION"] = "7.0" + + if not os.environ.get("TEST_CRYPT_SHARED"): + set_env("SKIP_CRYPT_SHARED") + + if opts.ssl: + extra_opts.append("--ssl") + if test_name != "ocsp": + certs = ROOT / "test/certificates" + set_env("TLS_CERT_KEY_FILE", certs / "client.pem") + set_env("TLS_PEM_KEY_FILE", certs / "server.pem") + set_env("TLS_CA_FILE", certs / "ca.pem") + + if opts.auth: + extra_opts.append("--auth") + + if opts.verbose: + extra_opts.append("-v") + elif opts.quiet: + extra_opts.append("-q") + + cmd = ["bash", f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh", *extra_opts] + run_command(cmd, cwd=DRIVERS_TOOLS) + + +if __name__ == "__main__": + start_server() diff --git a/.evergreen/scripts/run_tests.py b/.evergreen/scripts/run_tests.py new file mode 100644 index 0000000000..9f700d70e0 --- /dev/null +++ b/.evergreen/scripts/run_tests.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import json +import logging +import os +import platform +import shutil +import sys +from datetime import datetime +from pathlib import Path +from shutil import which + +import pytest +from utils import DRIVERS_TOOLS, LOGGER, ROOT, run_command + +AUTH = os.environ.get("AUTH", "noauth") +SSL = os.environ.get("SSL", "nossl") +UV_ARGS = os.environ.get("UV_ARGS", "") +TEST_PERF = os.environ.get("TEST_PERF") +GREEN_FRAMEWORK = os.environ.get("GREEN_FRAMEWORK") +TEST_ARGS = os.environ.get("TEST_ARGS", "").split() +TEST_NAME = os.environ.get("TEST_NAME") +SUB_TEST_NAME = os.environ.get("SUB_TEST_NAME") + + +def handle_perf(start_time: datetime): + end_time = datetime.now() + elapsed_secs = (end_time - start_time).total_seconds() + with open("results.json") as fid: + results = json.load(fid) + LOGGER.info("results.json:\n%s", json.dumps(results, indent=2)) + + results = dict( + status="PASS", + exit_code=0, + test_file="BenchMarkTests", + start=int(start_time.timestamp()), + end=int(end_time.timestamp()), + elapsed=elapsed_secs, + ) + report = dict(failures=0, results=[results]) + LOGGER.info("report.json\n%s", json.dumps(report, indent=2)) + + with open("report.json", "w", newline="\n") as fid: + json.dump(report, fid) + + +def handle_green_framework() -> None: + if GREEN_FRAMEWORK == "eventlet": + import eventlet + + # https://github.com/eventlet/eventlet/issues/401 + eventlet.sleep() + eventlet.monkey_patch() + elif GREEN_FRAMEWORK == "gevent": + from gevent import monkey + + monkey.patch_all() + + # Never run async tests with a framework. 
+    if len(TEST_ARGS) <= 1:
+        TEST_ARGS.extend(["-m", "not default_async and default"])
+    else:
+        for i in range(len(TEST_ARGS) - 1):
+            if "-m" in TEST_ARGS[i]:
+                TEST_ARGS[i + 1] = f"not default_async and {TEST_ARGS[i + 1]}"
+
+    LOGGER.info(f"Running tests with {GREEN_FRAMEWORK}...")
+
+
+def handle_c_ext() -> None:
+    if platform.python_implementation() != "CPython":
+        return
+    sys.path.insert(0, str(ROOT / "tools"))
+    from fail_if_no_c import main as fail_if_no_c
+
+    fail_if_no_c()
+
+
+def handle_pymongocrypt() -> None:
+    import pymongocrypt
+
+    LOGGER.info(f"pymongocrypt version: {pymongocrypt.__version__}")
+    LOGGER.info(f"libmongocrypt version: {pymongocrypt.libmongocrypt_version()}")
+
+
+def handle_aws_lambda() -> None:
+    env = os.environ.copy()
+    target_dir = ROOT / "test/lambda"
+    env["TEST_LAMBDA_DIRECTORY"] = str(target_dir)
+    env.setdefault("AWS_REGION", "us-east-1")
+    dirs = ["pymongo", "gridfs", "bson"]
+    # Store the original .so files.
+    before_sos = []
+    for dname in dirs:
+        before_sos.extend(f"{f.parent.name}/{f.name}" for f in (ROOT / dname).glob("*.so"))
+    # Build the C extensions.
+    docker = which("docker") or which("podman")
+    if not docker:
+        raise ValueError("Could not find docker!")
+    image = "quay.io/pypa/manylinux2014_x86_64:latest"
+    run_command(
+        f'{docker} run --rm -v "{ROOT}:/src" --platform linux/amd64 {image} /src/test/lambda/build_internal.sh'
+    )
+    for dname in dirs:
+        target = ROOT / "test/lambda/mongodb" / dname
+        shutil.rmtree(target, ignore_errors=True)
+        shutil.copytree(ROOT / dname, target)
+    # Remove the original so files from the lambda directory.
+    for so_path in before_sos:
+        (ROOT / "test/lambda/mongodb" / so_path).unlink()
+    # Remove the new so files from the ROOT directory.
+    for dname in dirs:
+        so_paths = [f"{f.parent.name}/{f.name}" for f in (ROOT / dname).glob("*.so")]
+        for so_path in list(so_paths):
+            if so_path not in before_sos:
+                Path(so_path).unlink()
+
+    script_name = "run-deployed-lambda-aws-tests.sh"
+    run_command(f"bash {DRIVERS_TOOLS}/.evergreen/aws_lambda/{script_name}", env=env)
+
+
+def run() -> None:
+    # Handle the green framework first so it can patch modules.
+    if GREEN_FRAMEWORK:
+        handle_green_framework()
+
+    # Ensure C extensions if applicable.
+    if not os.environ.get("NO_EXT"):
+        handle_c_ext()
+
+    if os.environ.get("PYMONGOCRYPT_LIB"):
+        handle_pymongocrypt()
+
+    LOGGER.info(f"Test setup:\n{AUTH=}\n{SSL=}\n{UV_ARGS=}\n{TEST_ARGS=}")
+
+    # Record the start time for a perf test.
+    if TEST_PERF:
+        start_time = datetime.now()
+
+    # Run mod_wsgi tests using the helper.
+    if TEST_NAME == "mod_wsgi":
+        from mod_wsgi_tester import test_mod_wsgi
+
+        test_mod_wsgi()
+        return
+
+    # Send kms tests to run remotely.
+    if TEST_NAME == "kms" and SUB_TEST_NAME in ["azure", "gcp"]:
+        from kms_tester import test_kms_send_to_remote
+
+        test_kms_send_to_remote(SUB_TEST_NAME)
+        return
+
+    # Send ecs tests to run remotely.
+    if TEST_NAME == "auth_aws" and SUB_TEST_NAME == "ecs":
+        run_command(f"{DRIVERS_TOOLS}/.evergreen/auth_aws/aws_setup.sh ecs")
+        return
+
+    # Send OIDC tests to run remotely.
+    if (
+        TEST_NAME == "auth_oidc"
+        and SUB_TEST_NAME != "default"
+        and not SUB_TEST_NAME.endswith("-remote")
+    ):
+        from oidc_tester import test_oidc_send_to_remote
+
+        test_oidc_send_to_remote(SUB_TEST_NAME)
+        return
+
+    # Run deployed aws lambda tests.
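
To make the marker rewrite in `handle_green_framework` above concrete before the dispatch below continues, here is the same logic run on a sample argument list (marker expression invented):

```python
# Sample pytest args; the real values come from the TEST_ARGS env variable.
test_args = ["-m", "default"]

if len(test_args) <= 1:
    test_args.extend(["-m", "not default_async and default"])
else:
    for i in range(len(test_args) - 1):
        if "-m" in test_args[i]:
            test_args[i + 1] = f"not default_async and {test_args[i + 1]}"

print(test_args)  # ['-m', 'not default_async and default']
```
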
+ if TEST_NAME == "aws_lambda": + handle_aws_lambda() + return + + if os.environ.get("DEBUG_LOG"): + TEST_ARGS.extend(f"-o log_cli_level={logging.DEBUG} -o log_cli=1".split()) + + # Run local tests. + ret = pytest.main(TEST_ARGS + sys.argv[1:]) + if ret != 0: + sys.exit(ret) + + # Handle perf test post actions. + if TEST_PERF: + handle_perf(start_time) + + +if __name__ == "__main__": + run() diff --git a/.evergreen/scripts/setup-dev-env.sh b/.evergreen/scripts/setup-dev-env.sh index ae4b44c626..6e6b5965bd 100755 --- a/.evergreen/scripts/setup-dev-env.sh +++ b/.evergreen/scripts/setup-dev-env.sh @@ -1,42 +1,59 @@ #!/bin/bash - +# Set up a development environment on an evergreen host. set -eu HERE=$(dirname ${BASH_SOURCE:-$0}) -pushd "$(dirname "$(dirname $HERE)")" > /dev/null +HERE="$( cd -- "$HERE" > /dev/null 2>&1 && pwd )" +ROOT=$(dirname "$(dirname $HERE)") +pushd $ROOT > /dev/null -# Source the env file to pick up common variables. +# Source the env files to pick up common variables. if [ -f $HERE/env.sh ]; then - source $HERE/env.sh + . $HERE/env.sh +fi +# PYTHON_BINARY or PYTHON_VERSION may be defined in test-env.sh. +if [ -f $HERE/test-env.sh ]; then + . $HERE/test-env.sh fi # Ensure dependencies are installed. -. $HERE/install-dependencies.sh +bash $HERE/install-dependencies.sh +# Get the appropriate UV_PYTHON. +. $ROOT/.evergreen/utils.sh -# Set the location of the python bin dir. -if [ "Windows_NT" = "${OS:-}" ]; then - BIN_DIR=.venv/Scripts -else - BIN_DIR=.venv/bin +if [ -z "${PYTHON_BINARY:-}" ]; then + if [ -n "${PYTHON_VERSION:-}" ]; then + PYTHON_BINARY=$(get_python_binary $PYTHON_VERSION) + else + PYTHON_BINARY=$(find_python3) + fi fi +export UV_PYTHON=${PYTHON_BINARY} +echo "Using python $UV_PYTHON" -# Ensure there is a python venv. -if [ ! -d $BIN_DIR ]; then - . .evergreen/utils.sh +# Add the default install path to the path if needed. +if [ -z "${PYMONGO_BIN_DIR:-}" ]; then + export PATH="$PATH:$HOME/.local/bin" +fi - if [ -z "${PYTHON_BINARY:-}" ]; then - PYTHON_BINARY=$(find_python3) - fi - export UV_PYTHON=${PYTHON_BINARY} - echo "export UV_PYTHON=$UV_PYTHON" >> $HERE/env.sh +# Set up venv, making sure c extensions build unless disabled. +if [ -z "${NO_EXT:-}" ]; then + export PYMONGO_C_EXT_MUST_BUILD=1 +fi +# Set up visual studio env on Windows spawn hosts. +if [ -f $HOME/.visualStudioEnv.sh ]; then + set +u + SSH_TTY=1 source $HOME/.visualStudioEnv.sh + set -u fi -echo "Using python $UV_PYTHON" uv sync --frozen -uv run --frozen --with pip pip install -e . + echo "Setting up python environment... done." # Ensure there is a pre-commit hook if there is a git checkout. if [ -d .git ] && [ ! -f .git/hooks/pre-commit ]; then uv run --frozen pre-commit install fi + +popd > /dev/null diff --git a/.evergreen/scripts/setup-encryption.sh b/.evergreen/scripts/setup-encryption.sh deleted file mode 100755 index 5b73240205..0000000000 --- a/.evergreen/scripts/setup-encryption.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -if [ -n "${test_encryption}" ]; then - bash .evergreen/setup-encryption.sh -fi diff --git a/.evergreen/scripts/setup-system.sh b/.evergreen/scripts/setup-system.sh index d78d924f6b..d8552e0ad2 100755 --- a/.evergreen/scripts/setup-system.sh +++ b/.evergreen/scripts/setup-system.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# Set up the system on an evergreen host. set -eu HERE=$(dirname ${BASH_SOURCE:-$0}) @@ -7,8 +7,35 @@ pushd "$(dirname "$(dirname $HERE)")" echo "Setting up system..." 
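
One subtlety in the `DEBUG_LOG` branch of `run_tests.py` above: `logging.DEBUG` interpolates as its numeric value, which pytest accepts as a log level. A standalone check:

```python
import logging

# Mirrors the TEST_ARGS extension in run_tests.py.
extra = f"-o log_cli_level={logging.DEBUG} -o log_cli=1".split()
print(extra)  # ['-o', 'log_cli_level=10', '-o', 'log_cli=1']
```
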
 bash .evergreen/scripts/configure-env.sh
 source .evergreen/scripts/env.sh
-bash .evergreen/scripts/prepare-resources.sh
 bash $DRIVERS_TOOLS/.evergreen/setup.sh
 bash .evergreen/scripts/install-dependencies.sh
 popd
+
+# Enable core dumps if enabled on the machine
+# Copied from https://github.com/mongodb/mongo/blob/master/etc/evergreen.yml
+if [ -f /proc/self/coredump_filter ]; then
+  # Set the shell process (and its children processes) to dump ELF headers (bit 4),
+  # anonymous shared mappings (bit 1), and anonymous private mappings (bit 0).
+  echo 0x13 >/proc/self/coredump_filter
+
+  if [ -f /sbin/sysctl ]; then
+    # Check that the core pattern is set explicitly on our distro image instead
+    # of being the OS's default value. This ensures that coredump names are consistent
+    # across distros and can be picked up by Evergreen.
+    core_pattern=$(/sbin/sysctl -n "kernel.core_pattern")
+    if [ "$core_pattern" = "dump_%e.%p.core" ]; then
+      echo "Enabling coredumps"
+      ulimit -c unlimited
+    fi
+  fi
+fi
+
+if [ "$(uname -s)" = "Darwin" ]; then
+  core_pattern_mac=$(/usr/sbin/sysctl -n "kern.corefile")
+  if [ "$core_pattern_mac" = "dump_%N.%P.core" ]; then
+    echo "Enabling coredumps"
+    ulimit -c unlimited
+  fi
+fi
+
 echo "Setting up system... done."
diff --git a/.evergreen/scripts/setup-tests.sh b/.evergreen/scripts/setup-tests.sh
index 65462b2a68..0b75051a68 100755
--- a/.evergreen/scripts/setup-tests.sh
+++ b/.evergreen/scripts/setup-tests.sh
@@ -1,27 +1,23 @@
-#!/bin/bash -eux
+#!/bin/bash
+# Set up the test environment, including secrets and services.
+set -eu
-
-PROJECT_DIRECTORY="$(pwd)"
-SCRIPT_DIR="$PROJECT_DIRECTORY/.evergreen/scripts"
+# Supported/used environment variables:
+#  AUTH                 Set to enable authentication. Defaults to "noauth"
+#  SSL                  Set to enable SSL. Defaults to "nossl"
+#  GREEN_FRAMEWORK      The green framework to test with, if any.
+#  COVERAGE             If non-empty, run the test suite with coverage.
+#  COMPRESSORS          If non-empty, install appropriate compressor.
+#  LIBMONGOCRYPT_URL    The URL to download libmongocrypt.
+#  TEST_CRYPT_SHARED    If non-empty, install crypt_shared lib.
+#  MONGODB_API_VERSION  The mongodb api version to use in tests.
+#  MONGODB_URI          If non-empty, use as the MONGODB_URI in tests.
-
-if [ -f "$SCRIPT_DIR/test-env.sh" ]; then
-  echo "Reading $SCRIPT_DIR/test-env.sh file"
-  . "$SCRIPT_DIR/test-env.sh"
-  exit 0
-fi
+SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0})
-cat <<EOT > "$SCRIPT_DIR"/test-env.sh
-export test_encryption="${test_encryption:-}"
-export test_encryption_pyopenssl="${test_encryption_pyopenssl:-}"
-export test_crypt_shared="${test_crypt_shared:-}"
-export test_pyopenssl="${test_pyopenssl:-}"
-export test_loadbalancer="${test_loadbalancer:-}"
-export test_serverless="${test_serverless:-}"
-export TEST_INDEX_MANAGEMENT="${TEST_INDEX_MANAGEMENT:-}"
-export TEST_DATA_LAKE="${TEST_DATA_LAKE:-}"
-export ORCHESTRATION_FILE="${ORCHESTRATION_FILE:-}"
-export AUTH="${AUTH:-noauth}"
-export SSL="${SSL:-nossl}"
-export PYTHON_BINARY="${PYTHON_BINARY:-}"
-EOT
+# Try to source the env file.
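
The `0x13` that `setup-system.sh` writes to `/proc/self/coredump_filter` above packs the three mapping-type bits named in its comment. A quick sanity check of the arithmetic:

```python
ANON_PRIVATE = 1 << 0  # bit 0: anonymous private mappings
ANON_SHARED = 1 << 1   # bit 1: anonymous shared mappings
ELF_HEADERS = 1 << 4   # bit 4: ELF headers
assert ANON_PRIVATE | ANON_SHARED | ELF_HEADERS == 0x13
```
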
+if [ -f $SCRIPT_DIR/env.sh ]; then + source $SCRIPT_DIR/env.sh +fi -chmod +x "$SCRIPT_DIR"/test-env.sh +uv run $SCRIPT_DIR/setup_tests.py "$@" diff --git a/.evergreen/scripts/setup_tests.py b/.evergreen/scripts/setup_tests.py new file mode 100644 index 0000000000..2ee8aa12ee --- /dev/null +++ b/.evergreen/scripts/setup_tests.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +import base64 +import io +import os +import platform +import shutil +import stat +import tarfile +from pathlib import Path +from urllib import request + +from utils import ( + DRIVERS_TOOLS, + ENV_FILE, + HERE, + LOGGER, + PLATFORM, + ROOT, + TEST_SUITE_MAP, + Distro, + get_test_options, + read_env, + run_command, + write_env, +) + +# Passthrough environment variables. +PASS_THROUGH_ENV = [ + "GREEN_FRAMEWORK", + "NO_EXT", + "MONGODB_API_VERSION", + "DEBUG_LOG", + "PYTHON_BINARY", + "PYTHON_VERSION", +] + +# Map the test name to test extra. +EXTRAS_MAP = { + "auth_aws": "aws", + "auth_oidc": "aws", + "encryption": "encryption", + "enterprise_auth": "gssapi", + "kms": "encryption", + "ocsp": "ocsp", + "pyopenssl": "ocsp", +} + + +# Map the test name to test group. +GROUP_MAP = dict(mockupdb="mockupdb", perf="perf") + +# The python version used for perf tests. +PERF_PYTHON_VERSION = "3.9.13" + + +def is_set(var: str) -> bool: + value = os.environ.get(var, "") + return len(value.strip()) > 0 + + +def get_distro() -> Distro: + name = "" + version_id = "" + arch = platform.machine() + with open("/etc/os-release") as fid: + for line in fid.readlines(): + line = line.replace('"', "") # noqa: PLW2901 + if line.startswith("NAME="): + _, _, name = line.strip().partition("=") + if line.startswith("VERSION_ID="): + _, _, version_id = line.strip().partition("=") + return Distro(name=name, version_id=version_id, arch=arch) + + +def setup_libmongocrypt(): + target = "" + if PLATFORM == "windows": + # PYTHON-2808 Ensure this machine has the CA cert for google KMS. + if is_set("TEST_FLE_GCP_AUTO"): + run_command('powershell.exe "Invoke-WebRequest -URI https://oauth2.googleapis.com/"') + target = "windows-test" + + elif PLATFORM == "darwin": + target = "macos" + + else: + distro = get_distro() + if distro.name.startswith("Debian"): + target = f"debian{distro.version_id}" + elif distro.name.startswith("Red Hat"): + if distro.version_id.startswith("7"): + target = "rhel-70-64-bit" + elif distro.version_id.startswith("8"): + if distro.arch == "aarch64": + target = "rhel-82-arm64" + else: + target = "rhel-80-64-bit" + + if not is_set("LIBMONGOCRYPT_URL"): + if not target: + raise ValueError("Cannot find libmongocrypt target for current platform!") + url = f"https://s3.amazonaws.com/mciuploads/libmongocrypt/{target}/master/latest/libmongocrypt.tar.gz" + else: + url = os.environ["LIBMONGOCRYPT_URL"] + + shutil.rmtree(HERE / "libmongocrypt", ignore_errors=True) + + LOGGER.info(f"Fetching {url}...") + with request.urlopen(request.Request(url), timeout=15.0) as response: # noqa: S310 + if response.status == 200: + fileobj = io.BytesIO(response.read()) + with tarfile.open("libmongocrypt.tar.gz", fileobj=fileobj) as fid: + fid.extractall(Path.cwd() / "libmongocrypt") + LOGGER.info(f"Fetching {url}... done.") + + run_command("ls -la libmongocrypt") + run_command("ls -la libmongocrypt/nocrypto") + + if PLATFORM == "windows": + # libmongocrypt's windows dll is not marked executable. 
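
The `get_distro` helper above needs only two fields from `/etc/os-release`. Here is how that parsing behaves on a sample payload (file contents invented):

```python
sample = 'NAME="Red Hat Enterprise Linux"\nVERSION_ID="8.7"\nID="rhel"\n'

name = version_id = ""
for line in sample.splitlines():
    line = line.replace('"', "")
    if line.startswith("NAME="):
        _, _, name = line.strip().partition("=")
    if line.startswith("VERSION_ID="):
        _, _, version_id = line.strip().partition("=")

print(name, version_id)  # Red Hat Enterprise Linux 8.7
```
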
+ run_command("chmod +x libmongocrypt/nocrypto/bin/mongocrypt.dll") + + +def load_config_from_file(path: str | Path) -> dict[str, str]: + config = read_env(path) + for key, value in config.items(): + write_env(key, value) + return config + + +def get_secrets(name: str) -> dict[str, str]: + secrets_dir = Path(f"{DRIVERS_TOOLS}/.evergreen/secrets_handling") + run_command(f"bash {secrets_dir.as_posix()}/setup-secrets.sh {name}", cwd=secrets_dir) + return load_config_from_file(secrets_dir / "secrets-export.sh") + + +def handle_test_env() -> None: + opts, _ = get_test_options("Set up the test environment and services.") + test_name = opts.test_name + sub_test_name = opts.sub_test_name + AUTH = "auth" if opts.auth else "noauth" + SSL = "ssl" if opts.ssl else "nossl" + TEST_ARGS = "" + + # Start compiling the args we'll pass to uv. + UV_ARGS = ["--extra test --no-group dev"] + + test_title = test_name + if sub_test_name: + test_title += f" {sub_test_name}" + + # Create the test env file with the initial set of values. + with ENV_FILE.open("w", newline="\n") as fid: + fid.write("#!/usr/bin/env bash\n") + fid.write("set +x\n") + ENV_FILE.chmod(ENV_FILE.stat().st_mode | stat.S_IEXEC) + + write_env("PIP_QUIET") # Quiet by default. + write_env("PIP_PREFER_BINARY") # Prefer binary dists by default. + write_env("UV_FROZEN") # Do not modify lock files. + + # Skip CSOT tests on non-linux platforms. + if PLATFORM != "linux": + write_env("SKIP_CSOT_TESTS") + + # Set an environment variable for the test name and sub test name. + write_env(f"TEST_{test_name.upper()}") + write_env("TEST_NAME", test_name) + write_env("SUB_TEST_NAME", sub_test_name) + + # Handle pass through env vars. + for var in PASS_THROUGH_ENV: + if is_set(var) or getattr(opts, var.lower(), ""): + write_env(var, os.environ.get(var, getattr(opts, var.lower(), ""))) + + if extra := EXTRAS_MAP.get(test_name, ""): + UV_ARGS.append(f"--extra {extra}") + + if group := GROUP_MAP.get(test_name, ""): + UV_ARGS.append(f"--group {group}") + + if test_name == "auth_oidc": + from oidc_tester import setup_oidc + + config = setup_oidc(sub_test_name) + if not config: + AUTH = "noauth" + + if test_name in ["aws_lambda", "search_index"]: + env = os.environ.copy() + env["MONGODB_VERSION"] = "7.0" + env["LAMBDA_STACK_NAME"] = "dbx-python-lambda" + write_env("LAMBDA_STACK_NAME", env["LAMBDA_STACK_NAME"]) + run_command( + f"bash {DRIVERS_TOOLS}/.evergreen/atlas/setup-atlas-cluster.sh", + env=env, + cwd=DRIVERS_TOOLS, + ) + + if test_name == "search_index": + AUTH = "auth" + + if test_name == "ocsp": + SSL = "ssl" + + write_env("AUTH", AUTH) + write_env("SSL", SSL) + LOGGER.info(f"Setting up '{test_title}' with {AUTH=} and {SSL=}...") + + if test_name == "aws_lambda": + UV_ARGS.append("--group pip") + # Store AWS creds if they were given. + if "AWS_ACCESS_KEY_ID" in os.environ: + for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]: + if key in os.environ: + write_env(key, os.environ[key]) + + if test_name == "data_lake": + # Stop any running mongo-orchestration which might be using the port. 
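
Stepping back from the per-test branches: the uv invocation is built up incrementally from the test name. A condensed sketch of that flow, with the maps trimmed to a few entries from earlier in this file:

```python
EXTRAS_MAP = {"auth_aws": "aws", "encryption": "encryption", "kms": "encryption", "ocsp": "ocsp"}
GROUP_MAP = {"mockupdb": "mockupdb", "perf": "perf"}


def uv_args_for(test_name: str) -> list[str]:
    # Every run starts from the test extra with the dev group excluded.
    args = ["--extra test --no-group dev"]
    if extra := EXTRAS_MAP.get(test_name, ""):
        args.append(f"--extra {extra}")
    if group := GROUP_MAP.get(test_name, ""):
        args.append(f"--group {group}")
    return args


print(uv_args_for("kms"))  # ['--extra test --no-group dev', '--extra encryption']
```
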
+ run_command(f"bash {DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh") + run_command(f"bash {DRIVERS_TOOLS}/.evergreen/atlas_data_lake/setup.sh") + AUTH = "auth" + + if AUTH != "noauth": + if test_name == "data_lake": + config = read_env(f"{DRIVERS_TOOLS}/.evergreen/atlas_data_lake/secrets-export.sh") + DB_USER = config["ADL_USERNAME"] + DB_PASSWORD = config["ADL_PASSWORD"] + elif test_name == "serverless": + run_command(f"bash {DRIVERS_TOOLS}/.evergreen/serverless/setup.sh") + config = read_env(f"{DRIVERS_TOOLS}/.evergreen/serverless/secrets-export.sh") + DB_USER = config["SERVERLESS_ATLAS_USER"] + DB_PASSWORD = config["SERVERLESS_ATLAS_PASSWORD"] + write_env("MONGODB_URI", config["SERVERLESS_URI"]) + write_env("SINGLE_MONGOS_LB_URI", config["SERVERLESS_URI"]) + write_env("MULTI_MONGOS_LB_URI", config["SERVERLESS_URI"]) + elif test_name == "auth_oidc": + DB_USER = config["OIDC_ADMIN_USER"] + DB_PASSWORD = config["OIDC_ADMIN_PWD"] + elif test_name == "search_index": + config = read_env(f"{DRIVERS_TOOLS}/.evergreen/atlas/secrets-export.sh") + DB_USER = config["DRIVERS_ATLAS_LAMBDA_USER"] + DB_PASSWORD = config["DRIVERS_ATLAS_LAMBDA_PASSWORD"] + write_env("MONGODB_URI", config["MONGODB_URI"]) + else: + DB_USER = "bob" + DB_PASSWORD = "pwd123" # noqa: S105 + write_env("DB_USER", DB_USER) + write_env("DB_PASSWORD", DB_PASSWORD) + LOGGER.info("Added auth, DB_USER: %s", DB_USER) + + if is_set("MONGODB_URI"): + write_env("PYMONGO_MUST_CONNECT", "true") + + if is_set("DISABLE_TEST_COMMANDS") or opts.disable_test_commands: + write_env("PYMONGO_DISABLE_TEST_COMMANDS", "1") + + if test_name == "enterprise_auth": + config = get_secrets("drivers/enterprise_auth") + if PLATFORM == "windows": + LOGGER.info("Setting GSSAPI_PASS") + write_env("GSSAPI_PASS", config["SASL_PASS"]) + write_env("GSSAPI_CANONICALIZE", "true") + else: + # BUILD-3830 + krb_conf = ROOT / ".evergreen/krb5.conf.empty" + krb_conf.touch() + write_env("KRB5_CONFIG", krb_conf) + LOGGER.info("Writing keytab") + keytab = base64.b64decode(config["KEYTAB_BASE64"]) + keytab_file = ROOT / ".evergreen/drivers.keytab" + with keytab_file.open("wb") as fid: + fid.write(keytab) + principal = config["PRINCIPAL"] + LOGGER.info("Running kinit") + os.environ["KRB5_CONFIG"] = str(krb_conf) + cmd = f"kinit -k -t {keytab_file} -p {principal}" + run_command(cmd) + + LOGGER.info("Setting GSSAPI variables") + write_env("GSSAPI_HOST", config["SASL_HOST"]) + write_env("GSSAPI_PORT", config["SASL_PORT"]) + write_env("GSSAPI_PRINCIPAL", config["PRINCIPAL"]) + + if test_name == "load_balancer": + SINGLE_MONGOS_LB_URI = os.environ.get( + "SINGLE_MONGOS_LB_URI", "mongodb://127.0.0.1:8000/?loadBalanced=true" + ) + MULTI_MONGOS_LB_URI = os.environ.get( + "MULTI_MONGOS_LB_URI", "mongodb://127.0.0.1:8001/?loadBalanced=true" + ) + if SSL != "nossl": + SINGLE_MONGOS_LB_URI += "&tls=true" + MULTI_MONGOS_LB_URI += "&tls=true" + write_env("SINGLE_MONGOS_LB_URI", SINGLE_MONGOS_LB_URI) + write_env("MULTI_MONGOS_LB_URI", MULTI_MONGOS_LB_URI) + if not DRIVERS_TOOLS: + raise RuntimeError("Missing DRIVERS_TOOLS") + cmd = f'bash "{DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh" start' + run_command(cmd) + + if test_name == "mod_wsgi": + from mod_wsgi_tester import setup_mod_wsgi + + setup_mod_wsgi(sub_test_name) + + if test_name == "ocsp": + if sub_test_name: + os.environ["OCSP_SERVER_TYPE"] = sub_test_name + for name in ["OCSP_SERVER_TYPE", "ORCHESTRATION_FILE"]: + if name not in os.environ: + raise ValueError(f"Please set {name}") + + server_type = 
os.environ["OCSP_SERVER_TYPE"] + orch_file = os.environ["ORCHESTRATION_FILE"] + ocsp_algo = orch_file.split("-")[0] + if server_type == "no-responder": + tls_should_succeed = "false" if "mustStaple-disableStapling" in orch_file else "true" + else: + tls_should_succeed = "true" if "valid" in server_type else "false" + + write_env("OCSP_TLS_SHOULD_SUCCEED", tls_should_succeed) + write_env("CA_FILE", f"{DRIVERS_TOOLS}/.evergreen/ocsp/{ocsp_algo}/ca.pem") + + if server_type != "no-responder": + env = os.environ.copy() + env["SERVER_TYPE"] = server_type + env["OCSP_ALGORITHM"] = ocsp_algo + run_command(f"bash {DRIVERS_TOOLS}/.evergreen/ocsp/setup.sh", env=env) + + # The mock OCSP responder MUST BE started before the mongod as the mongod expects that + # a responder will be available upon startup. + version = os.environ.get("VERSION", "latest") + cmd = [ + "bash", + f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh", + "--ssl", + "--version", + version, + ] + if opts.verbose: + cmd.append("-v") + elif opts.quiet: + cmd.append("-q") + run_command(cmd, cwd=DRIVERS_TOOLS) + + if SSL != "nossl": + if not DRIVERS_TOOLS: + raise RuntimeError("Missing DRIVERS_TOOLS") + write_env("CLIENT_PEM", f"{DRIVERS_TOOLS}/.evergreen/x509gen/client.pem") + write_env("CA_PEM", f"{DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem") + + compressors = os.environ.get("COMPRESSORS") or opts.compressor + if compressors == "snappy": + UV_ARGS.append("--extra snappy") + elif compressors == "zstd": + UV_ARGS.append("--extra zstd") + + if test_name in ["encryption", "kms"]: + # Check for libmongocrypt download. + if not (ROOT / "libmongocrypt").exists(): + setup_libmongocrypt() + + # TODO: Test with 'pip install pymongocrypt' + UV_ARGS.append("--group pymongocrypt_source") + + # Use the nocrypto build to avoid dependency issues with older windows/python versions. + BASE = ROOT / "libmongocrypt/nocrypto" + if PLATFORM == "linux": + if (BASE / "lib/libmongocrypt.so").exists(): + PYMONGOCRYPT_LIB = BASE / "lib/libmongocrypt.so" + else: + PYMONGOCRYPT_LIB = BASE / "lib64/libmongocrypt.so" + elif PLATFORM == "darwin": + PYMONGOCRYPT_LIB = BASE / "lib/libmongocrypt.dylib" + else: + PYMONGOCRYPT_LIB = BASE / "bin/mongocrypt.dll" + if not PYMONGOCRYPT_LIB.exists(): + raise RuntimeError("Cannot find libmongocrypt shared object file") + write_env("PYMONGOCRYPT_LIB", PYMONGOCRYPT_LIB.as_posix()) + # PATH is updated by configure-env.sh for access to mongocryptd. 
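
Both OCSP knobs above are derived from names alone. A worked example using the orchestration file cited later in `CONTRIBUTING.md`:

```python
# File names follow the DRIVERS_TOOLS convention: the algorithm comes first.
orch_file = "rsa-basic-tls-ocsp-disableStapling.json"
server_type = "valid"

ocsp_algo = orch_file.split("-")[0]  # -> "rsa"
if server_type == "no-responder":
    tls_should_succeed = "false" if "mustStaple-disableStapling" in orch_file else "true"
else:
    tls_should_succeed = "true" if "valid" in server_type else "false"

print(ocsp_algo, tls_should_succeed)  # rsa true
```
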
+ + if test_name == "encryption": + if not DRIVERS_TOOLS: + raise RuntimeError("Missing DRIVERS_TOOLS") + csfle_dir = Path(f"{DRIVERS_TOOLS}/.evergreen/csfle") + run_command(f"bash {csfle_dir}/setup-secrets.sh", cwd=csfle_dir) + load_config_from_file(csfle_dir / "secrets-export.sh") + run_command(f"bash {csfle_dir}/start-servers.sh") + + if sub_test_name == "pyopenssl": + UV_ARGS.append("--extra ocsp") + + if is_set("TEST_CRYPT_SHARED") or opts.crypt_shared: + config = read_env(f"{DRIVERS_TOOLS}/mo-expansion.sh") + CRYPT_SHARED_DIR = Path(config["CRYPT_SHARED_LIB_PATH"]).parent.as_posix() + LOGGER.info("Using crypt_shared_dir %s", CRYPT_SHARED_DIR) + if PLATFORM == "windows": + write_env("PATH", f"{CRYPT_SHARED_DIR}:$PATH") + else: + write_env( + "DYLD_FALLBACK_LIBRARY_PATH", + f"{CRYPT_SHARED_DIR}:${{DYLD_FALLBACK_LIBRARY_PATH:-}}", + ) + write_env("LD_LIBRARY_PATH", f"{CRYPT_SHARED_DIR}:${{LD_LIBRARY_PATH:-}}") + + if test_name == "kms": + from kms_tester import setup_kms + + setup_kms(sub_test_name) + + if test_name == "auth_aws" and sub_test_name != "ecs-remote": + auth_aws_dir = f"{DRIVERS_TOOLS}/.evergreen/auth_aws" + if "AWS_ROLE_SESSION_NAME" in os.environ: + write_env("AWS_ROLE_SESSION_NAME") + if sub_test_name != "ecs": + aws_setup = f"{auth_aws_dir}/aws_setup.sh" + run_command(f"bash {aws_setup} {sub_test_name}") + creds = read_env(f"{auth_aws_dir}/test-env.sh") + for name, value in creds.items(): + write_env(name, value) + else: + run_command(f"bash {auth_aws_dir}/setup-secrets.sh") + + if test_name == "atlas_connect": + get_secrets("drivers/atlas_connect") + # We do not want the default client_context to be initialized. + write_env("DISABLE_CONTEXT") + + if test_name == "perf": + data_dir = ROOT / "specifications/source/benchmarking/data" + if not data_dir.exists(): + run_command("git clone --depth 1 https://github.com/mongodb/specifications.git") + run_command("tar xf extended_bson.tgz", cwd=data_dir) + run_command("tar xf parallel.tgz", cwd=data_dir) + run_command("tar xf single_and_multi_document.tgz", cwd=data_dir) + write_env("TEST_PATH", str(data_dir)) + write_env("OUTPUT_FILE", str(ROOT / "results.json")) + # Overwrite the UV_PYTHON from the env.sh file. + write_env("UV_PYTHON", "") + + UV_ARGS.append(f"--python={PERF_PYTHON_VERSION}") + + # PYTHON-4769 Run perf_test.py directly otherwise pytest's test collection negatively + # affects the benchmark results. + if sub_test_name == "sync": + TEST_ARGS = f"test/performance/perf_test.py {TEST_ARGS}" + else: + TEST_ARGS = f"test/performance/async_perf_test.py {TEST_ARGS}" + + # Add coverage if requested. + # Only cover CPython. PyPy reports suspiciously low coverage. + if (is_set("COVERAGE") or opts.cov) and platform.python_implementation() == "CPython": + # Keep in sync with combine-coverage.sh. + # coverage >=5 is needed for relative_files=true. 
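
The nocrypto library lookup above varies only by path per platform; condensed into one function for illustration (a sketch, not part of the suite):

```python
from pathlib import Path


def pymongocrypt_lib(platform_name: str, base: Path) -> Path:
    # Mirrors the branch in handle_test_env(); linux may use lib or lib64.
    if platform_name == "linux":
        lib = base / "lib/libmongocrypt.so"
        return lib if lib.exists() else base / "lib64/libmongocrypt.so"
    if platform_name == "darwin":
        return base / "lib/libmongocrypt.dylib"
    return base / "bin/mongocrypt.dll"  # windows


print(pymongocrypt_lib("darwin", Path("libmongocrypt/nocrypto")))
```
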
+ UV_ARGS.append("--group coverage") + TEST_ARGS = f"{TEST_ARGS} --cov" + write_env("COVERAGE") + + if is_set("GREEN_FRAMEWORK") or opts.green_framework: + framework = opts.green_framework or os.environ["GREEN_FRAMEWORK"] + UV_ARGS.append(f"--group {framework}") + + else: + # Use --capture=tee-sys so pytest prints test output inline: + # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html + TEST_ARGS = f"-v --capture=tee-sys --durations=5 {TEST_ARGS}" + TEST_SUITE = TEST_SUITE_MAP.get(test_name) + if TEST_SUITE: + TEST_ARGS = f"-m {TEST_SUITE} {TEST_ARGS}" + + write_env("TEST_ARGS", TEST_ARGS) + write_env("UV_ARGS", " ".join(UV_ARGS)) + + LOGGER.info(f"Setting up test '{test_title}' with {AUTH=} and {SSL=}... done.") + + +if __name__ == "__main__": + handle_test_env() diff --git a/.evergreen/scripts/stop-load-balancer.sh b/.evergreen/scripts/stop-load-balancer.sh deleted file mode 100755 index 2d3c5366ec..0000000000 --- a/.evergreen/scripts/stop-load-balancer.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -cd "${DRIVERS_TOOLS}"/.evergreen || exit -DRIVERS_TOOLS=${DRIVERS_TOOLS} -bash "${DRIVERS_TOOLS}"/.evergreen/run-load-balancer.sh stop diff --git a/.evergreen/scripts/stop-server.sh b/.evergreen/scripts/stop-server.sh new file mode 100755 index 0000000000..7599387f5f --- /dev/null +++ b/.evergreen/scripts/stop-server.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Stop a server that was started using run-orchestration.sh in DRIVERS_TOOLS. +set -eu + +HERE=$(dirname ${BASH_SOURCE:-$0}) +HERE="$( cd -- "$HERE" > /dev/null 2>&1 && pwd )" + +# Try to source the env file. +if [ -f $HERE/env.sh ]; then + echo "Sourcing env file" + source $HERE/env.sh +fi + +bash ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh diff --git a/.evergreen/scripts/teardown-tests.sh b/.evergreen/scripts/teardown-tests.sh new file mode 100755 index 0000000000..898425b6cf --- /dev/null +++ b/.evergreen/scripts/teardown-tests.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Tear down any services that were used by tests. +set -eu + +SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0}) + +# Try to source the env file. +if [ -f $SCRIPT_DIR/env.sh ]; then + echo "Sourcing env inputs" + . $SCRIPT_DIR/env.sh +else + echo "Not sourcing env inputs" +fi + +# Handle test inputs. +if [ -f $SCRIPT_DIR/test-env.sh ]; then + echo "Sourcing test inputs" + . $SCRIPT_DIR/test-env.sh +else + echo "Missing test inputs, please run 'just setup-tests'" +fi + +# Teardown the test runner. +uv run $SCRIPT_DIR/teardown_tests.py diff --git a/.evergreen/scripts/teardown_tests.py b/.evergreen/scripts/teardown_tests.py new file mode 100644 index 0000000000..390e0a68eb --- /dev/null +++ b/.evergreen/scripts/teardown_tests.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import os +import shutil +import sys +from pathlib import Path + +from utils import DRIVERS_TOOLS, LOGGER, ROOT, run_command + +TEST_NAME = os.environ.get("TEST_NAME", "unconfigured") +SUB_TEST_NAME = os.environ.get("SUB_TEST_NAME") + +LOGGER.info(f"Tearing down tests of type '{TEST_NAME}'...") + +# Shut down csfle servers if applicable. +if TEST_NAME == "encryption": + run_command(f"bash {DRIVERS_TOOLS}/.evergreen/csfle/stop-servers.sh") + +# Shut down load balancer if applicable. +elif TEST_NAME == "load-balancer": + run_command(f"bash {DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh stop") + +# Tear down kms VM if applicable. 
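
Pausing between setup and teardown: for a plain non-perf run, the argument assembly at the end of `handle_test_env` yields a predictable pytest command line. A worked example with coverage enabled and the default (empty) suite marker:

```python
TEST_SUITE_MAP = {"default": "", "encryption": "encryption"}  # trimmed

test_args = ""
test_args = f"{test_args} --cov"  # coverage requested
test_args = f"-v --capture=tee-sys --durations=5 {test_args}"
suite = TEST_SUITE_MAP.get("default")
if suite:  # empty string means no -m filter for the default suite
    test_args = f"-m {suite} {test_args}"

print(test_args.split())  # ['-v', '--capture=tee-sys', '--durations=5', '--cov']
```
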
+elif TEST_NAME == "kms" and SUB_TEST_NAME in ["azure", "gcp"]:
+    from kms_tester import teardown_kms
+
+    teardown_kms(SUB_TEST_NAME)
+
+# Tear down OIDC if applicable.
+elif TEST_NAME == "auth_oidc":
+    from oidc_tester import teardown_oidc
+
+    teardown_oidc(SUB_TEST_NAME)
+
+# Tear down ocsp if applicable.
+elif TEST_NAME == "ocsp":
+    run_command(f"bash {DRIVERS_TOOLS}/.evergreen/ocsp/teardown.sh")
+
+# Tear down serverless if applicable.
+elif TEST_NAME == "serverless":
+    run_command(f"bash {DRIVERS_TOOLS}/.evergreen/serverless/teardown.sh")
+
+# Tear down atlas cluster if applicable.
+if TEST_NAME in ["aws_lambda", "search_index"]:
+    run_command(f"bash {DRIVERS_TOOLS}/.evergreen/atlas/teardown-atlas-cluster.sh")
+
+# Tear down auth_aws if applicable.
+# We do not run web-identity hosts on macOS, because the hosts lack permissions,
+# so there is no reason to run the teardown, which would error with a 401.
+elif TEST_NAME == "auth_aws" and sys.platform != "darwin":
+    run_command(f"bash {DRIVERS_TOOLS}/.evergreen/auth_aws/teardown.sh")
+
+# Tear down perf if applicable.
+elif TEST_NAME == "perf":
+    shutil.rmtree(ROOT / "specifications", ignore_errors=True)
+    Path(os.environ["OUTPUT_FILE"]).unlink(missing_ok=True)
+
+# Tear down mod_wsgi if applicable.
+elif TEST_NAME == "mod_wsgi":
+    from mod_wsgi_tester import teardown_mod_wsgi
+
+    teardown_mod_wsgi()
+
+# Tear down data_lake if applicable.
+elif TEST_NAME == "data_lake":
+    run_command(f"{DRIVERS_TOOLS}/.evergreen/atlas_data_lake/teardown.sh")
+
+# Tear down coverage if applicable.
+if os.environ.get("COVERAGE"):
+    shutil.rmtree(".pytest_cache", ignore_errors=True)
+
+LOGGER.info(f"Tearing down tests of type '{TEST_NAME}'... done.")
diff --git a/.evergreen/scripts/upload-coverage-report.sh b/.evergreen/scripts/upload-coverage-report.sh
index 71a2a80bb8..895664cbf2 100755
--- a/.evergreen/scripts/upload-coverage-report.sh
+++ b/.evergreen/scripts/upload-coverage-report.sh
@@ -1,3 +1,4 @@
 #!/bin/bash
-
+# Upload a coverage report to S3.
+set -eu
 aws s3 cp htmlcov/ s3://"$1"/coverage/"$2"/"$3"/htmlcov/ --recursive --acl public-read --region us-east-1
diff --git a/.evergreen/scripts/utils.py b/.evergreen/scripts/utils.py
new file mode 100644
index 0000000000..214a1fc347
--- /dev/null
+++ b/.evergreen/scripts/utils.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import argparse
+import dataclasses
+import logging
+import os
+import shlex
+import subprocess
+import sys
+from pathlib import Path
+from typing import Any
+
+HERE = Path(__file__).absolute().parent
+ROOT = HERE.parent.parent
+DRIVERS_TOOLS = os.environ.get("DRIVERS_TOOLS", "").replace(os.sep, "/")
+TMP_DRIVER_FILE = "/tmp/mongo-python-driver.tgz"  # noqa: S108
+
+LOGGER = logging.getLogger("test")
+logging.basicConfig(level=logging.INFO, format="%(levelname)-8s %(message)s")
+ENV_FILE = HERE / "test-env.sh"
+PLATFORM = "windows" if os.name == "nt" else sys.platform.lower()
+
+
+@dataclasses.dataclass
+class Distro:
+    name: str
+    version_id: str
+    arch: str
+
+
+# Map the test name to a test suite.
+TEST_SUITE_MAP = { + "atlas_connect": "atlas_connect", + "auth_aws": "auth_aws", + "auth_oidc": "auth_oidc", + "data_lake": "data_lake", + "default": "", + "default_async": "default_async", + "default_sync": "default", + "encryption": "encryption", + "enterprise_auth": "auth", + "search_index": "search_index", + "kms": "kms", + "load_balancer": "load_balancer", + "mockupdb": "mockupdb", + "pyopenssl": "", + "ocsp": "ocsp", + "perf": "perf", + "serverless": "", +} + +# Tests that require a sub test suite. +SUB_TEST_REQUIRED = ["auth_aws", "auth_oidc", "kms", "mod_wsgi", "perf"] + +EXTRA_TESTS = ["mod_wsgi", "aws_lambda"] + +# Tests that do not use run-orchestration directly. +NO_RUN_ORCHESTRATION = ["auth_oidc", "atlas_connect", "data_lake", "mockupdb", "serverless", "ocsp"] + + +def get_test_options( + description, require_sub_test_name=True, allow_extra_opts=False +) -> tuple[argparse.Namespace, list[str]]: + parser = argparse.ArgumentParser( + description=description, formatter_class=argparse.RawDescriptionHelpFormatter + ) + if require_sub_test_name: + parser.add_argument( + "test_name", + choices=sorted(list(TEST_SUITE_MAP) + EXTRA_TESTS), + nargs="?", + default="default", + help="The optional name of the test suite to set up, typically the same name as a pytest marker.", + ) + parser.add_argument( + "sub_test_name", nargs="?", help="The optional sub test name, for example 'azure'." + ) + else: + parser.add_argument( + "test_name", + choices=set(TEST_SUITE_MAP) - set(NO_RUN_ORCHESTRATION), + nargs="?", + default="default", + help="The optional name of the test suite to be run, which informs the server configuration.", + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Whether to log at the DEBUG level." + ) + parser.add_argument( + "--quiet", "-q", action="store_true", help="Whether to log at the WARNING level." + ) + parser.add_argument("--auth", action="store_true", help="Whether to add authentication.") + parser.add_argument("--ssl", action="store_true", help="Whether to add TLS configuration.") + + # Add the test modifiers. + if require_sub_test_name: + parser.add_argument( + "--debug-log", action="store_true", help="Enable pymongo standard logging." + ) + parser.add_argument("--cov", action="store_true", help="Add test coverage.") + parser.add_argument( + "--green-framework", + nargs=1, + choices=["eventlet", "gevent"], + help="Optional green framework to test against.", + ) + parser.add_argument( + "--compressor", + nargs=1, + choices=["zlib", "zstd", "snappy"], + help="Optional compression algorithm.", + ) + parser.add_argument("--crypt-shared", action="store_true", help="Test with crypt_shared.") + parser.add_argument("--no-ext", action="store_true", help="Run without c extensions.") + parser.add_argument( + "--mongodb-api-version", choices=["1"], help="MongoDB stable API version to use." + ) + parser.add_argument( + "--disable-test-commands", action="store_true", help="Disable test commands." + ) + + # Get the options. + if not allow_extra_opts: + opts, extra_opts = parser.parse_args(), [] + else: + opts, extra_opts = parser.parse_known_args() + if opts.verbose: + LOGGER.setLevel(logging.DEBUG) + elif opts.quiet: + LOGGER.setLevel(logging.WARNING) + + # Handle validation and environment variable overrides. 
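
Seen from the caller's side, the parser above turns a `just setup-tests` or `just run-server` invocation into an options object plus pass-through arguments. A trimmed-down sketch with only a few of the real flags:

```python
import argparse

parser = argparse.ArgumentParser(description="Set up the test environment and services.")
parser.add_argument("test_name", nargs="?", default="default")
parser.add_argument("sub_test_name", nargs="?")
parser.add_argument("--ssl", action="store_true")
parser.add_argument("--auth", action="store_true")

opts, extra = parser.parse_known_args(["kms", "azure", "--ssl", "--version", "8.0"])
print(opts.test_name, opts.sub_test_name, opts.ssl)  # kms azure True
print(extra)  # ['--version', '8.0'] -> handed through to run-orchestration.sh
```
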
+ test_name = opts.test_name + sub_test_name = opts.sub_test_name if require_sub_test_name else "" + if require_sub_test_name and test_name in SUB_TEST_REQUIRED and not sub_test_name: + raise ValueError(f"Test '{test_name}' requires a sub_test_name") + if "auth" in test_name or os.environ.get("AUTH") == "auth": + opts.auth = True + # 'auth_aws ecs' shouldn't have extra auth set. + if test_name == "auth_aws" and sub_test_name == "ecs": + opts.auth = False + if os.environ.get("SSL") == "ssl": + opts.ssl = True + return opts, extra_opts + + +def read_env(path: Path | str) -> dict[str, str]: + config = dict() + with Path(path).open() as fid: + for line in fid.readlines(): + if "=" not in line: + continue + name, _, value = line.strip().partition("=") + if value.startswith(('"', "'")): + value = value[1:-1] + name = name.replace("export ", "") + config[name] = value + return config + + +def write_env(name: str, value: Any = "1") -> None: + with ENV_FILE.open("a", newline="\n") as fid: + # Remove any existing quote chars. + value = str(value).replace('"', "") + fid.write(f'export {name}="{value}"\n') + + +def run_command(cmd: str | list[str], **kwargs: Any) -> None: + if isinstance(cmd, list): + cmd = " ".join(cmd) + LOGGER.info("Running command '%s'...", cmd) + kwargs.setdefault("check", True) + # Prevent overriding the python used by other tools. + env = kwargs.pop("env", os.environ).copy() + if "UV_PYTHON" in env: + del env["UV_PYTHON"] + kwargs["env"] = env + try: + subprocess.run(shlex.split(cmd), **kwargs) # noqa: PLW1510, S603 + except subprocess.CalledProcessError as e: + LOGGER.error(e.output) + LOGGER.error(str(e)) + sys.exit(e.returncode) + LOGGER.info("Running command '%s'... done.", cmd) + + +def create_archive() -> str: + run_command("git add .", cwd=ROOT) + run_command('git commit -m "add files"', check=False, cwd=ROOT) + run_command(f"git archive -o {TMP_DRIVER_FILE} HEAD", cwd=ROOT) + return TMP_DRIVER_FILE diff --git a/.evergreen/scripts/windows-fix.sh b/.evergreen/scripts/windows-fix.sh deleted file mode 100755 index cb4fa44130..0000000000 --- a/.evergreen/scripts/windows-fix.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -set +x -. src/.evergreen/scripts/env.sh -# shellcheck disable=SC2044 -for i in $(find "$DRIVERS_TOOLS"/.evergreen "$PROJECT_DIRECTORY"/.evergreen -name \*.sh); do - < "$i" tr -d '\r' >"$i".new - mv "$i".new "$i" -done -# Copy client certificate because symlinks do not work on Windows. -cp "$DRIVERS_TOOLS"/.evergreen/x509gen/client.pem "$MONGO_ORCHESTRATION_HOME"/lib/client.pem diff --git a/.evergreen/setup-encryption.sh b/.evergreen/setup-encryption.sh deleted file mode 100755 index b403ef9ca8..0000000000 --- a/.evergreen/setup-encryption.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -set -o xtrace - -if [ -z "${DRIVERS_TOOLS}" ]; then - echo "Missing environment variable DRIVERS_TOOLS" - exit 1 -fi - -TARGET="" - -if [ "Windows_NT" = "${OS:-''}" ]; then # Magic variable in cygwin - # PYTHON-2808 Ensure this machine has the CA cert for google KMS. 
- powershell.exe "Invoke-WebRequest -URI https://oauth2.googleapis.com/" > /dev/null || true - TARGET="windows-test" -fi - -if [ "$(uname -s)" = "Darwin" ]; then - TARGET="macos" -fi - -if [ "$(uname -s)" = "Linux" ]; then - rhel_ver=$(awk -F'=' '/VERSION_ID/{ gsub(/"/,""); print $2}' /etc/os-release) - arch=$(uname -m) - echo "RHEL $rhel_ver $arch" - if [[ $rhel_ver =~ 7 ]]; then - TARGET="rhel-70-64-bit" - elif [[ $rhel_ver =~ 8 ]]; then - if [ "$arch" = "x86_64" ]; then - TARGET="rhel-80-64-bit" - elif [ "$arch" = "arm" ]; then - TARGET="rhel-82-arm64" - fi - fi -fi - -if [ -z "$LIBMONGOCRYPT_URL" ] && [ -n "$TARGET" ]; then - LIBMONGOCRYPT_URL="https://s3.amazonaws.com/mciuploads/libmongocrypt/$TARGET/master/latest/libmongocrypt.tar.gz" -fi - -if [ -z "$LIBMONGOCRYPT_URL" ]; then - echo "Cannot test client side encryption without LIBMONGOCRYPT_URL!" - exit 1 -fi -rm -rf libmongocrypt libmongocrypt.tar.gz -echo "Fetching $LIBMONGOCRYPT_URL..." -curl -O "$LIBMONGOCRYPT_URL" -echo "Fetching $LIBMONGOCRYPT_URL...done" -mkdir libmongocrypt -tar xzf libmongocrypt.tar.gz -C ./libmongocrypt -ls -la libmongocrypt -ls -la libmongocrypt/nocrypto - -if [ -z "${SKIP_SERVERS:-}" ]; then - PYTHON_BINARY_OLD=${PYTHON_BINARY} - export PYTHON_BINARY="" - bash "${DRIVERS_TOOLS}"/.evergreen/csfle/setup-secrets.sh - export PYTHON_BINARY=$PYTHON_BINARY_OLD - bash "${DRIVERS_TOOLS}"/.evergreen/csfle/start-servers.sh -fi diff --git a/.evergreen/setup-spawn-host.sh b/.evergreen/setup-spawn-host.sh index c20e1c756e..bada61e568 100755 --- a/.evergreen/setup-spawn-host.sh +++ b/.evergreen/setup-spawn-host.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# Set up a remote evergreen spawn host. set -eu if [ -z "$1" ] @@ -16,4 +16,4 @@ rsync -az -e ssh --exclude '.git' --filter=':- .gitignore' -r . $target:$remote_ echo "Copying files to $target... done" ssh $target $remote_dir/.evergreen/scripts/setup-system.sh -ssh $target "cd $remote_dir && PYTHON_BINARY=${PYTHON_BINARY:-} just install" +ssh $target "cd $remote_dir && PYTHON_BINARY=${PYTHON_BINARY:-} .evergreen/scripts/setup-dev-env.sh" diff --git a/.evergreen/sync-spawn-host.sh b/.evergreen/sync-spawn-host.sh index de3374a008..61dd84ec22 100755 --- a/.evergreen/sync-spawn-host.sh +++ b/.evergreen/sync-spawn-host.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Synchronize local files to a remote Evergreen spawn host. +set -eu if [ -z "$1" ] then @@ -7,9 +9,12 @@ fi target=$1 user=${target%@*} +remote_dir=/home/$user/mongo-python-driver +echo "Copying files to $target..." +rsync -az -e ssh --exclude '.git' --filter=':- .gitignore' -r . $target:$remote_dir +echo "Copying files to $target... done." echo "Syncing files to $target..." -rsync -haz -e ssh --exclude '.git' --filter=':- .gitignore' -r . $target:/home/$user/mongo-python-driver # shellcheck disable=SC2034 fswatch -o . | while read f; do rsync -hazv -e ssh --exclude '.git' --filter=':- .gitignore' -r . $target:/home/$user/mongo-python-driver; done echo "Syncing files to $target... done." 
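
The two spawn-host scripts above split responsibilities: one-shot provisioning versus continuous file sync. The provisioning half, expressed as a Python sketch (host name and paths are placeholders):

```python
import subprocess

target = "user@spawn-host"  # placeholder spawn host address
user = target.split("@")[0]
remote_dir = f"/home/{user}/mongo-python-driver"

# Mirrors setup-spawn-host.sh: provision the system, then the dev environment.
subprocess.run(["ssh", target, f"{remote_dir}/.evergreen/scripts/setup-system.sh"], check=True)
subprocess.run(
    ["ssh", target, f"cd {remote_dir} && .evergreen/scripts/setup-dev-env.sh"], check=True
)
```
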
diff --git a/.evergreen/teardown-encryption.sh b/.evergreen/teardown-encryption.sh deleted file mode 100755 index 5ce2f1d71b..0000000000 --- a/.evergreen/teardown-encryption.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -set -o xtrace - -if [ -z "${DRIVERS_TOOLS}" ]; then - echo "Missing environment variable DRIVERS_TOOLS" -fi - -bash ${DRIVERS_TOOLS}/.evergreen/csfle/stop-servers.sh -rm -rf libmongocrypt/ libmongocrypt.tar.gz mongocryptd.pid diff --git a/.evergreen/utils.sh b/.evergreen/utils.sh index e044b3d766..faecde05fd 100755 --- a/.evergreen/utils.sh +++ b/.evergreen/utils.sh @@ -1,34 +1,32 @@ #!/bin/bash - +# Utility functions used by pymongo evergreen scripts. set -eu find_python3() { PYTHON="" - # Add a fallback system python3 if it is available and Python 3.9+. - if is_python_39 "$(command -v python3)"; then - PYTHON="$(command -v python3)" - fi # Find a suitable toolchain version, if available. if [ "$(uname -s)" = "Darwin" ]; then - # macos 11.00 - if [ -d "/Library/Frameworks/Python.Framework/Versions/3.10" ]; then - PYTHON="/Library/Frameworks/Python.Framework/Versions/3.10/bin/python3" - # macos 10.14 - elif [ -d "/Library/Frameworks/Python.Framework/Versions/3.9" ]; then - PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3" - fi + PYTHON="/Library/Frameworks/Python.Framework/Versions/3.9/bin/python3" elif [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin PYTHON="C:/python/Python39/python.exe" else # Prefer our own toolchain, fall back to mongodb toolchain if it has Python 3.9+. if [ -f "/opt/python/3.9/bin/python3" ]; then - PYTHON="/opt/python/3.9/bin/python3" + PYTHON="/opt/python/Current/bin/python3" + elif is_python_39 "$(command -v /opt/mongodbtoolchain/v5/bin/python3)"; then + PYTHON="/opt/mongodbtoolchain/v5/bin/python3" elif is_python_39 "$(command -v /opt/mongodbtoolchain/v4/bin/python3)"; then PYTHON="/opt/mongodbtoolchain/v4/bin/python3" elif is_python_39 "$(command -v /opt/mongodbtoolchain/v3/bin/python3)"; then PYTHON="/opt/mongodbtoolchain/v3/bin/python3" fi fi + # Add a fallback system python3 if it is available and Python 3.9+. + if [ -z "$PYTHON" ]; then + if is_python_39 "$(command -v python3)"; then + PYTHON="$(command -v python3)" + fi + fi if [ -z "$PYTHON" ]; then echo "Cannot test without python3.9+ installed!" exit 1 @@ -115,3 +113,24 @@ is_python_39() { return 1 fi } + + +# Function that gets a python binary given a python version string. +# Versions can be of the form 3.xx or pypy3.xx. +get_python_binary() { + version=$1 + if [ "$(uname -s)" = "Darwin" ]; then + PYTHON="/Library/Frameworks/Python.Framework/Versions/$version/bin/python3" + elif [ "Windows_NT" = "${OS:-}" ]; then + version=$(echo $version | cut -d. 
-f1,2 | sed 's/\.//g') + PYTHON="C:/python/Python$version/python.exe" + else + PYTHON="/opt/python/$version/bin/python3" + fi + if is_python_39 "$(command -v $PYTHON)"; then + echo "$PYTHON" + else + echo "Could not find suitable python binary for '$version'" >&2 + return 1 + fi +} diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index bb2418cf89..98cfa2f43f 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -54,7 +54,6 @@ jobs: queries: security-extended config: | paths-ignore: - - '.github/**' - 'doc/**' - 'tools/**' - 'test/**' diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 5100c70d43..81f86721ef 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -34,12 +34,11 @@ jobs: # Github Actions doesn't support pairing matrix values together, let's improvise # https://github.com/github/feedback/discussions/7835#discussioncomment-1769026 buildplat: - - [ubuntu-20.04, "manylinux_x86_64", "cp3*-manylinux_x86_64"] - - [ubuntu-24.04-arm, "manylinux_aarch64", "cp3*-manylinux_aarch64"] - # Disabled pending PYTHON-5058 - # - [ubuntu-24.04, "manylinux_ppc64le", "cp3*-manylinux_ppc64le"] - # - [ubuntu-24.04, "manylinux_s390x", "cp3*-manylinux_s390x"] - - [ubuntu-20.04, "manylinux_i686", "cp3*-manylinux_i686"] + - [ubuntu-latest, "manylinux_x86_64", "cp3*-manylinux_x86_64"] + - [ubuntu-latest, "manylinux_aarch64", "cp3*-manylinux_aarch64"] + - [ubuntu-latest, "manylinux_ppc64le", "cp3*-manylinux_ppc64le"] + - [ubuntu-latest, "manylinux_s390x", "cp3*-manylinux_s390x"] + - [ubuntu-latest, "manylinux_i686", "cp3*-manylinux_i686"] - [windows-2019, "win_amd6", "cp3*-win_amd64"] - [windows-2019, "win32", "cp3*-win32"] - [macos-14, "macos", "cp*-macosx_*"] @@ -63,6 +62,10 @@ jobs: if: runner.os == 'Linux' uses: docker/setup-qemu-action@v3 with: + # setup-qemu-action by default uses `tonistiigi/binfmt:latest` image, + # which is out of date. This causes seg faults during build. + # Here we manually fix the version. + image: tonistiigi/binfmt:qemu-v8.1.5 platforms: all - name: Install cibuildwheel diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index bcf37d1a22..21c7ca5f7a 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -20,10 +20,11 @@ env: # Changes per repo PRODUCT_NAME: PyMongo # Changes per branch - SILK_ASSET_GROUP: mongodb-python-driver EVERGREEN_PROJECT: mongo-python-driver # Constant - DRY_RUN: ${{ inputs.dry_run == 'true' }} + # inputs will be empty on a scheduled run. so, we only set dry_run + # to 'false' when the input is set to 'false'. + DRY_RUN: ${{ ! 
contains(inputs.dry_run, 'false') }} FOLLOWING_VERSION: ${{ inputs.following_version || '' }} VERSION: ${{ inputs.version || '10.10.10.10' }} @@ -35,6 +36,7 @@ jobs: pre-publish: environment: release runs-on: ubuntu-latest + if: github.repository_owner == 'mongodb' || github.event_name == 'workflow_dispatch' permissions: id-token: write contents: write @@ -119,7 +121,6 @@ jobs: version: ${{ env.VERSION }} following_version: ${{ env.FOLLOWING_VERSION }} product_name: ${{ env.PRODUCT_NAME }} - silk_asset_group: ${{ env.SILK_ASSET_GROUP }} evergreen_project: ${{ env.EVERGREEN_PROJECT }} token: ${{ github.token }} dry_run: ${{ env.DRY_RUN }} diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 3760e308a5..3c3eef989e 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -22,13 +22,13 @@ jobs: - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v5 + - name: Install just + uses: extractions/setup-just@v3 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: + enable-cache: true python-version: "3.9" - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - - name: Install just - uses: extractions/setup-just@v2 - name: Install Python dependencies run: | just install @@ -53,50 +53,30 @@ jobs: # supercharge/mongodb-github-action requires containers so we don't test other platforms runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - os: [ubuntu-20.04] + os: [ubuntu-latest] python-version: ["3.9", "pypy-3.10", "3.13", "3.13t"] name: CPython ${{ matrix.python-version }}-${{ matrix.os }} steps: - uses: actions/checkout@v4 with: persist-credentials: false - - if: ${{ matrix.python-version == '3.13t' }} - name: Setup free-threaded Python - uses: deadsnakes/action@v3.2.0 - with: - python-version: 3.13 - nogil: true - - if: ${{ matrix.python-version != '3.13t' }} - name: Setup Python - uses: actions/setup-python@v5 + - name: Install just + uses: extractions/setup-just@v3 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: + enable-cache: true python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - allow-prereleases: true - - name: Install just - uses: extractions/setup-just@v2 - name: Install dependencies - run: | - if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then - # Just can't be installed on 3.13t, use pytest directly. - pip install . 
- pip install -r requirements/test.txt - else - just install - fi + run: just install - name: Start MongoDB uses: supercharge/mongodb-github-action@1.12.0 with: mongodb-version: 6.0 - name: Run tests - run: | - if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then - pytest -v --durations=5 --maxfail=10 - else - just test - fi + run: just test doctest: runs-on: ubuntu-latest @@ -105,24 +85,21 @@ jobs: - uses: actions/checkout@v4 with: persist-credentials: false - - name: Setup Python - uses: actions/setup-python@v5 + - name: Install just + uses: extractions/setup-just@v3 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: + enable-cache: true python-version: "3.9" - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - - name: Install just - uses: extractions/setup-just@v2 - name: Start MongoDB uses: supercharge/mongodb-github-action@1.12.0 with: mongodb-version: '8.0.0-rc4' - name: Install dependencies - run: | - just install + run: just install - name: Run tests - run: | - just docs-test + run: just docs-test docs: name: Docs Checks @@ -131,20 +108,17 @@ jobs: - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - # Build docs on lowest supported Python for furo - python-version: '3.9' + enable-cache: true + python-version: "3.9" - name: Install just - uses: extractions/setup-just@v2 + uses: extractions/setup-just@v3 - name: Install dependencies - run: | - just install + run: just install - name: Build docs - run: | - just docs + run: just docs linkcheck: name: Link Check @@ -153,20 +127,17 @@ jobs: - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - # Build docs on lowest supported Python for furo - python-version: '3.9' + enable-cache: true + python-version: "3.9" - name: Install just - uses: extractions/setup-just@v2 + uses: extractions/setup-just@v3 - name: Install dependencies - run: | - just install + run: just install - name: Build docs - run: | - just docs-linkcheck + run: just docs-linkcheck typing: name: Typing Tests @@ -178,13 +149,13 @@ jobs: - uses: actions/checkout@v4 with: persist-credentials: false - - uses: actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v5 with: + enable-cache: true python-version: "${{matrix.python}}" - cache: 'pip' - cache-dependency-path: 'pyproject.toml' - name: Install just - uses: extractions/setup-just@v2 + uses: extractions/setup-just@v3 - name: Install dependencies run: | just install diff --git a/.gitignore b/.gitignore index 2582c517fd..a88a7556e2 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ mongocryptd.pid .idea/ .vscode/ .nova/ +.temp/ venv/ secrets-export.sh libmongocrypt.tar.gz @@ -26,13 +27,16 @@ libmongocrypt/ expansion.yml *expansions.yml .evergreen/scripts/env.sh +.evergreen/scripts/test-env.sh +specifications/ +results.json # Lambda temp files test/lambda/.aws-sam -test/lambda/env.json test/lambda/mongodb/pymongo/* test/lambda/mongodb/gridfs/* test/lambda/mongodb/bson/* +test/lambda/*.json # test results and logs xunit-results/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0b06ab0dc..a570e55ad1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,7 @@ repos: - id: check-added-large-files - id: check-case-conflict - id: check-toml 
+ - id: check-json - id: check-yaml exclude: template.yaml - id: debug-statements @@ -115,3 +116,9 @@ repos: (?x)( .evergreen/retry-with-backoff.sh ) + - id: generate-config + name: generate-config + entry: .evergreen/scripts/generate-config.sh + language: python + require_serial: true + additional_dependencies: ["shrub.py>=3.9.0", "pyyaml>=6.0.2"] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 536110fcfc..60583022b7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -178,7 +178,7 @@ documentation including narrative docs, and the [Sphinx docstring format](https: You can build the documentation locally by running: ```bash -just docs-build +just docs ``` When updating docs, it can be helpful to run the live docs server as: @@ -204,25 +204,173 @@ the pages will re-render and the browser will automatically refresh. `just test test/test_change_stream.py::TestUnifiedChangeStreamsErrors::test_change_stream_errors_on_ElectionInProgress`. - Use the `-k` argument to select tests by pattern. -## Running Load Balancer Tests Locally -- Install `haproxy` (available as `brew install haproxy` on macOS). -- Clone `drivers-evergreen-tools`: - `git clone git@github.com:mongodb-labs/drivers-evergreen-tools.git`. -- Start the servers using - `LOAD_BALANCER=true TOPOLOGY=sharded_cluster AUTH=noauth SSL=nossl MONGODB_VERSION=6.0 DRIVERS_TOOLS=$PWD/drivers-evergreen-tools MONGO_ORCHESTRATION_HOME=$PWD/drivers-evergreen-tools/.evergreen/orchestration $PWD/drivers-evergreen-tools/.evergreen/run-orchestration.sh`. -- Start the load balancer using: - `MONGODB_URI='mongodb://localhost:27017,localhost:27018/' $PWD/drivers-evergreen-tools/.evergreen/run-load-balancer.sh start`. -- Run the tests from the `pymongo` checkout directory using: - `TEST_LOADBALANCER=1 just test-eg`. - -## Running Encryption Tests Locally +## Running tests that require secrets, services, or other configuration + +### Prerequisites + - Clone `drivers-evergreen-tools`: - `git clone git@github.com:mongodb-labs/drivers-evergreen-tools.git`. -- Run `export DRIVERS_TOOLS=$PWD/drivers-evergreen-tools` -- Run `AWS_PROFILE= just setup-encryption` after setting up your AWS profile with `aws configure sso`. -- Run the tests with `TEST_ENCRYPTION=1 just test-eg`. -- When done, run `just teardown-encryption` to clean up. + `git clone git@github.com:mongodb-labs/drivers-evergreen-tools.git`. +- Run `export DRIVERS_TOOLS=$PWD/drivers-evergreen-tools`. This can be put into a `.bashrc` file + for convenience. +- Set up access to [Drivers test secrets](https://github.com/mongodb-labs/drivers-evergreen-tools/tree/master/.evergreen/secrets_handling#secrets-handling). + +### Usage + +- Run `just run-server` with optional args to set up the server. All given options will be passed to + `run-orchestration.sh` in `$DRIVERS_TOOLS`. See `$DRIVERS_TOOLS/evergreen/run-orchestration.sh -h` + for a full list of options. +- Run `just setup-tests` with optional args to set up the test environment, secrets, etc. + See `just setup-tests -h` for a full list of available options. +- Run `just run-tests` to run the tests in an appropriate Python environment. +- When done, run `just teardown-tests` to clean up and `just stop-server` to stop the server. + +### Encryption tests + +- Run `just run-server` to start the server. +- Run `just setup-tests encryption`. +- Run the tests with `just run-tests`. + +### Load balancer tests + +- Install `haproxy` (available as `brew install haproxy` on macOS). 
+- Start the server with `just run-server load_balancer`.
+- Set up the test with `just setup-tests load_balancer`.
+- Run the tests with `just run-tests`.
+
+### AWS auth tests
+
+- Run `just run-server auth_aws` to start the server.
+- Run `just setup-tests auth_aws <test-type>` to set up the AWS test.
+- Run the tests with `just run-tests`.
+
+### OIDC auth tests
+
+- Run `just setup-tests auth_oidc <test-type>` to set up the OIDC test.
+- Run the tests with `just run-tests`.
+
+The supported types are [`default`, `azure`, `gcp`, `eks`, `aks`, and `gke`].
+For the `eks` test, you will need to set up access to the `drivers-test-secrets-role`; see the [Wiki](https://wiki.corp.mongodb.com/spaces/DRIVERS/pages/239737385/Using+AWS+Secrets+Manager+to+Store+Testing+Secrets).
+
+### KMS tests
+
+For KMS tests that are run locally and expected to fail, in this case using `azure`:
+
+- Run `just run-server`.
+- Run `just setup-tests kms azure-fail`.
+- Run `just run-tests`.
+
+For KMS tests that run remotely and are expected to pass, in this case using `gcp`:
+
+- Run `just setup-tests kms gcp`.
+- Run `just run-tests`.
+
+### Enterprise Auth tests
+
+Note: these tests can only be run from an Evergreen host.
+
+- Run `just run-server enterprise_auth`.
+- Run `just setup-tests enterprise_auth`.
+- Run `just run-tests`.
+
+### Atlas Connect tests
+
+- Run `just setup-tests atlas_connect`.
+- Run `just run-tests`.
+
+### Search Index tests
+
+- Run `just run-server search_index`.
+- Run `just setup-tests search_index`.
+- Run `just run-tests`.
+
+### MockupDB tests
+
+- Run `just setup-tests mockupdb`.
+- Run `just run-tests`.
+
+### Doc tests
+
+The doc tests require a running server.
+
+- Run `just run-server`.
+- Run `just docs-test`.
+
+### Free-threaded Python Tests
+
+In the Evergreen builds, the tests are configured to use the free-threaded Python from the toolchain.
+Locally you can run:
+
+- Run `just run-server`.
+- Run `just setup-tests`.
+- Run `UV_PYTHON=3.13t just run-tests`.
+
+### AWS Lambda tests
+
+You will need to set up access to the `drivers-test-secrets-role`; see the [Wiki](https://wiki.corp.mongodb.com/spaces/DRIVERS/pages/239737385/Using+AWS+Secrets+Manager+to+Store+Testing+Secrets).
+
+- Run `just setup-tests aws_lambda`.
+- Run `just run-tests`.
+
+### mod_wsgi tests
+
+Note: these tests can only be run from an Evergreen Linux host that has the Python toolchain.
+
+- Run `just run-server`.
+- Run `just setup-tests mod_wsgi <mode>`.
+- Run `just run-tests`.
+
+The `mode` can be `standalone` or `embedded`. For the `replica_set` version of the tests, use
+`TOPOLOGY=replica_set just run-server`.
+
+### Atlas Data Lake tests
+
+You must have `docker` or `podman` installed locally.
+
+- Run `just setup-tests data_lake`.
+- Run `just run-tests`.
+
+### OCSP tests
+
+- Export the orchestration file, e.g. `export ORCHESTRATION_FILE=rsa-basic-tls-ocsp-disableStapling.json`.
+This corresponds to a config file in `$DRIVERS_TOOLS/.evergreen/orchestration/configs/servers`.
+MongoDB servers on macOS and Windows do not staple OCSP responses and only support RSA.
+NOTE: because the mock OCSP responder MUST be started prior to the server starting, the OCSP tests start the server
+as part of `setup-tests`.
+
+- Run `just setup-tests ocsp <server-type>` (options are "valid", "revoked", "valid-delegate", "revoked-delegate").
+- Run `just run-tests`.
+
+If you are running one of the `no-responder` tests, omit the `run-server` step.
+
+### Perf Tests
+
+- Start the appropriate server, e.g.
+## Adding a new test suite
+
+- If adding new test files that should only be run for that test suite, add a pytest marker to each file and add
+  the marker to the list of pytest markers in `pyproject.toml` (see the sketch after this list). Then add the test
+  suite to the `TEST_SUITE_MAP` in `.evergreen/scripts/utils.py`. If for some reason it is not a pytest-runnable
+  test, add it to the list of `EXTRA_TESTS` instead.
+- If the test uses Atlas or otherwise doesn't use `run-orchestration.sh`, add it to the `NO_RUN_ORCHESTRATION` list in
+  `.evergreen/scripts/utils.py`.
+- If there is something special required to run the local server, or there is an extra flag that should always be set,
+  like `AUTH`, add that logic to `.evergreen/scripts/run_server.py`.
+- The bulk of the logic will typically be in `.evergreen/scripts/setup_tests.py`. This is where you should fetch
+  secrets, start services, and use `write_env` to make those secrets and any other needed environment variables
+  available to the tests.
+- If there are any special test considerations, including not running `pytest` at all, handle them in
+  `.evergreen/scripts/run_tests.py`.
+- If there are any services or Atlas clusters to tear down, handle them in `.evergreen/scripts/teardown_tests.py`.
+- Add functions to generate the test variant(s) and task(s) to `.evergreen/scripts/generate_config.py`.
+- Regenerate the test variants and tasks using `pre-commit run --all-files generate-config`.
+- Make sure to add instructions for running the test suite to `CONTRIBUTING.md`.
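As a hypothetical illustration of the marker step above, a test file for a new `widget` suite would carry a module-level marker; matching `widget` entries would also be added to the markers list in `pyproject.toml` and to `TEST_SUITE_MAP` (`widget` and the test body are stand-ins, not a real suite):

```python
# test/test_widget.py -- hypothetical file for a new "widget" test suite.
from __future__ import annotations

import pytest

# Selects this whole module when pytest runs with `-m widget`.
pytestmark = [pytest.mark.widget]


def test_widget_placeholder():
    assert True
```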
 ## Re-sync Spec Tests

@@ -261,6 +409,11 @@ To prevent the `synchro` hook from accidentally overwriting code, it first check
 of a file is changing and not its async counterpart, and will fail.
 In the unlikely scenario that you want to override this behavior, first export `OVERRIDE_SYNCHRO_CHECK=1`.

+Sometimes, the `synchro` hook will fail and introduce changes to many previously unmodified files. This is due to static
+Python errors, such as missing imports, incorrect syntax, or other fatal typos. To resolve these issues,
+run `pre-commit run --all-files --hook-stage manual ruff` and fix all reported errors before running the `synchro`
+hook again.
+
 ## Converting a test to async

 The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
diff --git a/bson/__init__.py b/bson/__init__.py
index fc6efe0d59..790ac06ef1 100644
--- a/bson/__init__.py
+++ b/bson/__init__.py
@@ -1386,7 +1386,7 @@ def is_valid(bson: bytes) -> bool:
     :param bson: the data to be validated
     """
     if not isinstance(bson, bytes):
-        raise TypeError("BSON data must be an instance of a subclass of bytes")
+        raise TypeError(f"BSON data must be an instance of a subclass of bytes, not {type(bson)}")

     try:
         _bson_to_dict(bson, DEFAULT_CODEC_OPTIONS)
diff --git a/bson/binary.py b/bson/binary.py
index f90dce226c..6698e55ccc 100644
--- a/bson/binary.py
+++ b/bson/binary.py
@@ -14,7 +14,6 @@
 from __future__ import annotations

 import struct
-from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, Union, overload
 from uuid import UUID
@@ -227,7 +226,6 @@ class BinaryVectorDtype(Enum):
     PACKED_BIT = b"\x10"


-@dataclass
 class BinaryVector:
     """Vector of numbers along with metadata for binary interoperability.

     .. versionadded:: 4.10
@@ -247,6 +245,16 @@ def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, paddin
         self.dtype = dtype
         self.padding = padding

+    def __repr__(self) -> str:
+        return f"BinaryVector(dtype={self.dtype}, padding={self.padding}, data={self.data})"
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, BinaryVector):
+            return False
+        return (
+            self.dtype == other.dtype and self.padding == other.padding and self.data == other.data
+        )
+

 class Binary(bytes):
     """Representation of BSON binary data.
@@ -290,7 +298,7 @@ def __new__(
         subtype: int = BINARY_SUBTYPE,
     ) -> Binary:
         if not isinstance(subtype, int):
-            raise TypeError("subtype must be an instance of int")
+            raise TypeError(f"subtype must be an instance of int, not {type(subtype)}")
         if subtype >= 256 or subtype < 0:
             raise ValueError("subtype must be contained in [0, 256)")
         # Support any type that implements the buffer protocol.
@@ -321,7 +329,7 @@ def from_uuid(
         .. versionadded:: 3.11
         """
         if not isinstance(uuid, UUID):
-            raise TypeError("uuid must be an instance of uuid.UUID")
+            raise TypeError(f"uuid must be an instance of uuid.UUID, not {type(uuid)}")

         if uuid_representation not in ALL_UUID_REPRESENTATIONS:
             raise ValueError(
@@ -450,6 +458,10 @@ def from_vector(
             raise ValueError(f"padding does not apply to {dtype=}")
         elif dtype == BinaryVectorDtype.PACKED_BIT:  # pack ints in [0, 255] as unsigned uint8
             format_str = "B"
+            if not 0 <= padding <= 7:
+                raise ValueError(f"{padding=}. It must be in [0, 7].")
+            if padding and not vector:
+                raise ValueError("Empty vector with non-zero padding.")
         elif dtype == BinaryVectorDtype.FLOAT32:  # pack floats as float32
             format_str = "f"
             if padding:
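From the caller's side, the padding checks above constrain `Binary.from_vector` roughly as in this sketch (the byte values are arbitrary; `padding` counts the ignored bits of the final byte):

```python
from bson.binary import Binary, BinaryVectorDtype

# Pack two bytes of bits; padding=4 marks four bits of the final byte
# as ignored, so the logical bit vector holds 12 bits.
vec = Binary.from_vector([0b11110000, 0b10100000], BinaryVectorDtype.PACKED_BIT, padding=4)

decoded = vec.as_vector()
assert decoded.dtype == BinaryVectorDtype.PACKED_BIT
assert decoded.padding == 4
assert list(decoded.data) == [0b11110000, 0b10100000]

# Padding outside [0, 7] is rejected up front:
# Binary.from_vector([255], BinaryVectorDtype.PACKED_BIT, padding=8)  # ValueError
```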
@@ -470,7 +482,7 @@ def as_vector(self) -> BinaryVector:
         """
         if self.subtype != VECTOR_SUBTYPE:
-            raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.")
+            raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")

         position = 0
         dtype, padding = struct.unpack_from(" Code:
         if not isinstance(code, str):
-            raise TypeError("code must be an instance of str")
+            raise TypeError(f"code must be an instance of str, not {type(code)}")

         self = str.__new__(cls, code)
@@ -67,7 +67,7 @@ def __new__(

         if scope is not None:
             if not isinstance(scope, _Mapping):
-                raise TypeError("scope must be an instance of dict")
+                raise TypeError(f"scope must be an instance of dict, not {type(scope)}")
             if self.__scope is not None:
                 self.__scope.update(scope)  # type: ignore
             else:
diff --git a/bson/codec_options.py b/bson/codec_options.py
index 3a0b83b7be..258a777a1b 100644
--- a/bson/codec_options.py
+++ b/bson/codec_options.py
@@ -401,17 +401,23 @@ def __new__(
                 "uuid_representation must be a value from bson.binary.UuidRepresentation"
             )
         if not isinstance(unicode_decode_error_handler, str):
-            raise ValueError("unicode_decode_error_handler must be a string")
+            raise ValueError(
+                f"unicode_decode_error_handler must be a string, not {type(unicode_decode_error_handler)}"
+            )

         if tzinfo is not None:
             if not isinstance(tzinfo, datetime.tzinfo):
-                raise TypeError("tzinfo must be an instance of datetime.tzinfo")
+                raise TypeError(
+                    f"tzinfo must be an instance of datetime.tzinfo, not {type(tzinfo)}"
+                )
             if not tz_aware:
                 raise ValueError("cannot specify tzinfo without also setting tz_aware=True")

         type_registry = type_registry or TypeRegistry()

         if not isinstance(type_registry, TypeRegistry):
-            raise TypeError("type_registry must be an instance of TypeRegistry")
+            raise TypeError(
+                f"type_registry must be an instance of TypeRegistry, not {type(type_registry)}"
+            )

         return tuple.__new__(
             cls,
diff --git a/bson/dbref.py b/bson/dbref.py
index 6c21b8162c..40bdb73cff 100644
--- a/bson/dbref.py
+++ b/bson/dbref.py
@@ -56,9 +56,9 @@ def __init__(
         .. seealso:: The MongoDB documentation on `dbrefs `_.
         """
         if not isinstance(collection, str):
-            raise TypeError("collection must be an instance of str")
+            raise TypeError(f"collection must be an instance of str, not {type(collection)}")
         if database is not None and not isinstance(database, str):
-            raise TypeError("database must be an instance of str")
+            raise TypeError(f"database must be an instance of str, not {type(database)}")

         self.__collection = collection
         self.__id = id
diff --git a/bson/decimal128.py b/bson/decimal128.py
index 016afb5eb8..92c054d878 100644
--- a/bson/decimal128.py
+++ b/bson/decimal128.py
@@ -277,7 +277,7 @@ def from_bid(cls: Type[Decimal128], value: bytes) -> Decimal128:
             point in Binary Integer Decimal (BID) format).
""" if not isinstance(value, bytes): - raise TypeError("value must be an instance of bytes") + raise TypeError(f"value must be an instance of bytes, not {type(value)}") if len(value) != 16: raise ValueError("value must be exactly 16 bytes") return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) # type: ignore diff --git a/bson/timestamp.py b/bson/timestamp.py index 3e76e7baad..949bd7b36c 100644 --- a/bson/timestamp.py +++ b/bson/timestamp.py @@ -58,9 +58,9 @@ def __init__(self, time: Union[datetime.datetime, int], inc: int) -> None: time = time - offset time = int(calendar.timegm(time.timetuple())) if not isinstance(time, int): - raise TypeError("time must be an instance of int") + raise TypeError(f"time must be an instance of int, not {type(time)}") if not isinstance(inc, int): - raise TypeError("inc must be an instance of int") + raise TypeError(f"inc must be an instance of int, not {type(inc)}") if not 0 <= time < UPPERBOUND: raise ValueError("time must be contained in [0, 2**32)") if not 0 <= inc < UPPERBOUND: diff --git a/doc/api/bson/binary.rst b/doc/api/bson/binary.rst index 084fd02d50..7084a45b4e 100644 --- a/doc/api/bson/binary.rst +++ b/doc/api/bson/binary.rst @@ -16,6 +16,7 @@ .. autodata:: MD5_SUBTYPE .. autodata:: COLUMN_SUBTYPE .. autodata:: SENSITIVE_SUBTYPE + .. autodata:: VECTOR_SUBTYPE .. autodata:: USER_DEFINED_SUBTYPE .. autoclass:: UuidRepresentation diff --git a/doc/api/gridfs/asynchronous/grid_file.rst b/doc/api/gridfs/asynchronous/grid_file.rst new file mode 100644 index 0000000000..fbf34adc8a --- /dev/null +++ b/doc/api/gridfs/asynchronous/grid_file.rst @@ -0,0 +1,19 @@ +:mod:`grid_file` -- Async tools for representing files stored in GridFS +======================================================================= + +.. automodule:: gridfs.asynchronous.grid_file + :synopsis: Async tools for representing files stored in GridFS + + .. autoclass:: AsyncGridIn + :members: + + .. autoattribute:: _id + + .. autoclass:: AsyncGridOut + :members: + + .. autoattribute:: _id + .. automethod:: __aiter__ + + .. autoclass:: AsyncGridOutCursor + :members: diff --git a/doc/api/gridfs/asynchronous/index.rst b/doc/api/gridfs/asynchronous/index.rst new file mode 100644 index 0000000000..0904d10f98 --- /dev/null +++ b/doc/api/gridfs/asynchronous/index.rst @@ -0,0 +1,18 @@ +:mod:`gridfs async` -- Async tools for working with GridFS +========================================================== + +.. warning:: This API is currently in beta, meaning the classes, methods, + and behaviors described within may change before the full release. + If you come across any bugs during your use of this API, + please file a Jira ticket in the "Python Driver" project at https://jira.mongodb.org/browse/PYTHON. + +.. automodule:: gridfs.asynchronous + :synopsis: Async tools for working with GridFS + :members: AsyncGridFS, AsyncGridFSBucket + +Sub-modules: + +.. toctree:: + :maxdepth: 2 + + grid_file diff --git a/doc/api/gridfs/index.rst b/doc/api/gridfs/index.rst index b81fbde782..190c561d05 100644 --- a/doc/api/gridfs/index.rst +++ b/doc/api/gridfs/index.rst @@ -8,7 +8,8 @@ Sub-modules: .. toctree:: - :maxdepth: 2 + :maxdepth: 3 + asynchronous/index errors grid_file diff --git a/doc/api/index.rst b/doc/api/index.rst index 437f2cc6a6..339f5843bf 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -3,7 +3,7 @@ API Documentation The PyMongo distribution contains three top-level packages for interacting with MongoDB. 
:mod:`bson` is an implementation of the -`BSON format `_, :mod:`pymongo` is a +`BSON format `_, :mod:`pymongo` is a full-featured driver for MongoDB, and :mod:`gridfs` is a set of tools for working with the `GridFS `_ storage diff --git a/doc/async-tutorial.rst b/doc/async-tutorial.rst index 2ccf011d8e..1884631ec3 100644 --- a/doc/async-tutorial.rst +++ b/doc/async-tutorial.rst @@ -385,7 +385,7 @@ Indexing Adding indexes can help accelerate certain queries and can also add additional functionality to querying and storing documents. In this example, we'll demonstrate how to create a `unique index -`_ on a key that rejects +`_ on a key that rejects documents whose value for that key already exists in the index. First, we'll need to create the index: @@ -420,3 +420,10 @@ the collection: DuplicateKeyError: E11000 duplicate key error index: test_database.profiles.$user_id_1 dup key: { : 212 } .. seealso:: The MongoDB documentation on `indexes `_ + +Task Cancellation +----------------- +`Cancelling `_ an asyncio Task +that is running a PyMongo operation is treated as a fatal interrupt. Any connections, cursors, and transactions +involved in a cancelled Task will be safely closed and cleaned up as part of the cancellation. If those resources are +also used elsewhere, attempting to utilize them after the cancellation will result in an error. diff --git a/doc/changelog.rst b/doc/changelog.rst index 1f3efb8ad0..2606947e12 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,7 +1,88 @@ Changelog ========= -Changes in Version 4.11.0 (YYYY/MM/DD) +Changes in Version 4.12.1 (2025/04/29) +-------------------------------------- + +Version 4.12.1 is a bug fix release. + +- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL. +- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels. +- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``. +- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string + could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description, + nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type + errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()". +- Fixed a bug where MongoDB cluster topology changes could cause asynchronous operations to take much longer to complete + due to holding the Topology lock while closing stale connections. +- Fixed a bug that would cause AsyncMongoClient to attempt to use PyOpenSSL when available, resulting in errors such as + "pymongo.errors.ServerSelectionTimeoutError: 'SSLContext' object has no attribute 'wrap_bio'". + +Issues Resolved +............... + +See the `PyMongo 4.12.1 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. _PyMongo 4.12.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=43094 + +Changes in Version 4.12.0 (2025/04/08) +-------------------------------------- + +.. warning:: Driver support for MongoDB 4.0 reached end of life in April 2025. + PyMongo 4.12 will be the last release to support MongoDB 4.0. 
+
+PyMongo 4.12 brings a number of changes including:
+
+- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
+  :class:`~pymongo.encryption_options.AutoEncryptionOpts`.
+- Support for $lookup in CSFLE and QE on MongoDB 8.1+.
+- pymongocrypt>=1.13 is now required for :ref:`In-Use Encryption` support.
+- Added :meth:`gridfs.asynchronous.grid_file.AsyncGridFSBucket.rename_by_name` and :meth:`gridfs.grid_file.GridFSBucket.rename_by_name`
+  for more performant renaming of a file with multiple revisions.
+- Added :meth:`gridfs.asynchronous.grid_file.AsyncGridFSBucket.delete_by_name` and :meth:`gridfs.grid_file.GridFSBucket.delete_by_name`
+  for more performant deletion of a file with multiple revisions.
+- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
+  To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
+- Added index hinting support to the
+  :meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
+  :meth:`~pymongo.collection.Collection.distinct` commands.
+- Deprecated the ``hedge`` parameter for
+  :class:`~pymongo.read_preferences.PrimaryPreferred`,
+  :class:`~pymongo.read_preferences.Secondary`,
+  :class:`~pymongo.read_preferences.SecondaryPreferred`,
+  :class:`~pymongo.read_preferences.Nearest`. Support for ``hedge`` will be removed in PyMongo 5.0.
+- Removed PyOpenSSL support from the asynchronous API due to limitations of the CPython asyncio.Protocol SSL implementation.
+- Allow valid SRV hostnames with fewer than 3 parts.
+
+Issues Resolved
+...............
+
+See the `PyMongo 4.12 release notes in JIRA`_ for the list of resolved issues
+in this release.
+
+.. _PyMongo 4.12 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=41916
+
+Changes in Version 4.11.2 (2025/03/05)
+--------------------------------------
+
+Version 4.11.2 is a bug fix release.
+
+- Fixed a bug where :meth:`~pymongo.database.Database.command` would fail when attempting to run the bulkWrite command.
+
+Issues Resolved
+...............
+
+See the `PyMongo 4.11.2 release notes in JIRA`_ for the list of resolved issues in this release.
+
+.. _PyMongo 4.11.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=42506
+
+Changes in Version 4.11.1 (2025/02/10)
+--------------------------------------
+
+- Fixed support for prebuilt ``ppc64le`` and ``s390x`` wheels.
+
+Changes in Version 4.11.0 (2025/01/28)
 --------------------------------------

 .. warning:: PyMongo 4.11 drops support for Python 3.8 and PyPy 3.9: Python 3.9+ or PyPy 3.10+ is now required.
@@ -180,7 +261,7 @@ PyMongo 4.9 brings a number of improvements including:
 function-as-a-service (FaaS) like AWS Lambda, Google Cloud Functions, and
 Microsoft Azure Functions. On some FaaS systems, there is a ``fork()``
 operation at function startup. By delaying the connection to the first
 operation, we avoid a deadlock. See
-`Is PyMongo Fork-Safe`_ for more information.
+:ref:`pymongo-fork-safe` for more information.


 Issues Resolved
@@ -189,7 +270,6 @@ Issues Resolved
 See the `PyMongo 4.9 release notes in JIRA`_ for the list of resolved issues
 in this release.

-.. _Is PyMongo Fork-Safe : https://www.mongodb.com/docs/languages/python/pymongo-driver/current/faq/#is-pymongo-fork-safe-
..
_PyMongo 4.9 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=39940 @@ -3033,7 +3113,7 @@ fixes. Highlights include: :class:`~gridfs.grid_file.GridOutCursor`. - Greatly improved :doc:`support for mod_wsgi ` when using PyMongo's C extensions. Read `Jesse's blog post - `_ for details. + `_ for details. - Improved C extension support for ARM little endian. Breaking changes @@ -3288,7 +3368,7 @@ Important New Features: - Support for mongos failover. - A new :meth:`~pymongo.collection.Collection.aggregate` method to support MongoDB's new `aggregation framework - `_. + `_. - Support for legacy Java and C# byte order when encoding and decoding UUIDs. - Support for connecting directly to an arbiter. @@ -3652,7 +3732,7 @@ Changes in Version 1.9 (2010/09/28) Version 1.9 adds a new package to the PyMongo distribution, :mod:`bson`. :mod:`bson` contains all of the `BSON -`_ encoding and decoding logic, and the BSON +`_ encoding and decoding logic, and the BSON types that were formerly in the :mod:`pymongo` package. The following modules have been renamed: @@ -3785,7 +3865,7 @@ Changes in Version 1.7 (2010/06/17) Version 1.7 is a recommended upgrade for all PyMongo users. The full release notes are below, and some more in depth discussion of the highlights is `here -`_. +`_. - no longer attempt to build the C extension on big-endian systems. - added :class:`~bson.min_key.MinKey` and @@ -3836,7 +3916,7 @@ The biggest change in version 1.6 is a complete re-implementation of :mod:`gridfs` with a lot of improvements over the old implementation. There are many details and examples of using the new API in `this blog post -`_. The +`_. The old API has been removed in this version, so existing code will need to be modified before upgrading to 1.6. diff --git a/doc/compatibility-policy.rst b/doc/compatibility-policy.rst index 834f86ce54..9721877d4d 100644 --- a/doc/compatibility-policy.rst +++ b/doc/compatibility-policy.rst @@ -52,7 +52,7 @@ deprecated PyMongo features. .. seealso:: The Python documentation on `the warnings module`_, and `the -W command line option`_. -.. _semantic versioning: http://semver.org/ +.. _semantic versioning: https://semver.org/ .. _DeprecationWarning: https://docs.python.org/3/library/exceptions.html#DeprecationWarning diff --git a/doc/conf.py b/doc/conf.py index f82c719361..c3ee5d8900 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -88,7 +88,7 @@ "https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-monitoring.md#requesting-an-immediate-check", "https://github.com/mongodb/libmongocrypt/blob/master/bindings/python/README.rst#installing-from-source", r"https://wiki.centos.org/[\w/]*", - r"http://sourceforge.net/", + r"https://sourceforge.net/", ] # -- Options for extensions ---------------------------------------------------- diff --git a/doc/examples/aggregation.rst b/doc/examples/aggregation.rst index 9b1a89fba7..e7e3df6ce1 100644 --- a/doc/examples/aggregation.rst +++ b/doc/examples/aggregation.rst @@ -87,4 +87,4 @@ you can add computed fields, create new virtual sub-objects, and extract sub-fields into the top-level of results. .. 
seealso:: The full documentation for MongoDB's `aggregation framework - `_ + `_ diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index a92222bafc..3f1137969d 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -191,7 +191,7 @@ Two extra ``authMechanismProperties`` are supported on Windows platforms: >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM@example.com/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_REALM:otherrealm" -.. _kerberos: http://pypi.python.org/pypi/kerberos +.. _kerberos: https://pypi.python.org/pypi/kerberos .. _pykerberos: https://pypi.python.org/pypi/pykerberos .. _winkerberos: https://pypi.python.org/pypi/winkerberos/ diff --git a/doc/examples/copydb.rst b/doc/examples/copydb.rst index b37677b5c2..c8026ba05f 100644 --- a/doc/examples/copydb.rst +++ b/doc/examples/copydb.rst @@ -67,7 +67,7 @@ Versions of PyMongo before 3.0 included a ``copy_database`` helper method, but it has been removed. .. _copyDatabase function in the mongo shell: - http://mongodb.com/docs/manual/reference/method/db.copyDatabase/ + https://mongodb.com/docs/manual/reference/method/db.copyDatabase/ .. _Copy a Database: https://www.mongodb.com/docs/database-tools/mongodump/mongodump-examples/#copy-and-clone-databases diff --git a/doc/examples/gevent.rst b/doc/examples/gevent.rst index 0ab41c1ec6..f62697d19f 100644 --- a/doc/examples/gevent.rst +++ b/doc/examples/gevent.rst @@ -1,7 +1,7 @@ Gevent ====== -PyMongo supports `Gevent `_. Simply call Gevent's +PyMongo supports `Gevent `_. Simply call Gevent's ``monkey.patch_all()`` before loading any other modules: .. code-block:: pycon diff --git a/doc/examples/gridfs.rst b/doc/examples/gridfs.rst index 5f40805d79..52920adbda 100644 --- a/doc/examples/gridfs.rst +++ b/doc/examples/gridfs.rst @@ -14,7 +14,7 @@ objects (e.g. files) in MongoDB. .. seealso:: The API docs for :mod:`gridfs`. .. seealso:: `This blog post - `_ + `_ for some motivation behind this API. Setup diff --git a/doc/examples/high_availability.rst b/doc/examples/high_availability.rst index 8f94aba074..80026153f8 100644 --- a/doc/examples/high_availability.rst +++ b/doc/examples/high_availability.rst @@ -2,7 +2,7 @@ High Availability and PyMongo ============================= PyMongo makes it easy to write highly available applications whether -you use a `single replica set `_ +you use a `single replica set `_ or a `large sharded cluster `_. @@ -10,17 +10,17 @@ Connecting to a Replica Set --------------------------- PyMongo makes working with `replica sets -`_ easy. Here we'll launch a new +`_ easy. Here we'll launch a new replica set and show how to handle both initialization and normal connections with PyMongo. -.. seealso:: The MongoDB documentation on `replication `_. +.. seealso:: The MongoDB documentation on `replication `_. Starting a Replica Set ~~~~~~~~~~~~~~~~~~~~~~ The main `replica set documentation -`_ contains extensive information +`_ contains extensive information about setting up a new replica set or migrating an existing MongoDB setup, be sure to check that out. Here, we'll just do the bare minimum to get a three node replica set setup locally. 
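Connecting to such a three-node set from PyMongo is a one-liner; a minimal sketch (the `foo` set name and local ports are assumptions mirroring the surrounding example):

```python
from pymongo import MongoClient

# Any reachable seed member lets the driver discover the rest of the set.
client = MongoClient(
    "mongodb://localhost:27017,localhost:27018,localhost:27019/?replicaSet=foo"
)
print(client.primary)      # (host, port) of the current primary
print(client.secondaries)  # set of (host, port) pairs
```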
diff --git a/doc/examples/index.rst b/doc/examples/index.rst index ac450470ef..57682fa1af 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -6,7 +6,7 @@ of how to accomplish specific tasks with MongoDB and PyMongo. Unless otherwise noted, all examples assume that a MongoDB instance is running on the default host and port. Assuming you have `downloaded -and installed `_ +and installed `_ MongoDB, you can start it like so: .. code-block:: bash diff --git a/doc/examples/tls.rst b/doc/examples/tls.rst index 9241ac23e7..ee4d75027e 100644 --- a/doc/examples/tls.rst +++ b/doc/examples/tls.rst @@ -3,7 +3,7 @@ TLS/SSL and PyMongo PyMongo supports connecting to MongoDB over TLS/SSL. This guide covers the configuration options supported by PyMongo. See `the server documentation -`_ to configure +`_ to configure MongoDB. .. warning:: Industry best practices recommend, and some regulations require, diff --git a/doc/faq.rst b/doc/faq.rst index 73d0ec8966..7656481d89 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -53,9 +53,9 @@ a non `async-signal-safe`_ function. For examples of deadlocks or crashes that could occur see `PYTHON-3406`_. For a long but interesting read about the problems of Python locks in -multithreaded contexts with ``fork()``, see http://bugs.python.org/issue6721. +multithreaded contexts with ``fork()``, see https://bugs.python.org/issue6721. -.. _not fork-safe: http://bugs.python.org/issue6721 +.. _not fork-safe: https://bugs.python.org/issue6721 .. _OpenSSL: https://github.com/openssl/openssl/issues/19066 .. _fork(): https://man7.org/linux/man-pages/man2/fork.2.html .. _signal-safety(7): https://man7.org/linux/man-pages/man7/signal-safety.7.html @@ -174,10 +174,10 @@ Does PyMongo support asynchronous frameworks like Gevent, asyncio, Tornado, or T PyMongo fully supports :doc:`Gevent `. To use MongoDB with `asyncio `_ -or `Tornado `_, see the +or `Tornado `_, see the `Motor `_ project. -For `Twisted `_, see `TxMongo +For `Twisted `_, see `TxMongo `_. Its stated mission is to keep feature parity with PyMongo. @@ -381,7 +381,7 @@ Can you add attribute style access for documents? ------------------------------------------------- This request has come up a number of times but we've decided not to implement anything like this. The relevant `jira case -`_ has some information +`_ has some information about the decision, but here is a brief summary: 1. This will pollute the attribute namespace for documents, so could @@ -451,7 +451,7 @@ in Flask_ (other web frameworks are similar):: How can I use PyMongo from Django? ---------------------------------- -`Django `_ is a popular Python web +`Django `_ is a popular Python web framework. Django includes an ORM, :mod:`django.db`. Currently, there's no official MongoDB backend for Django. @@ -468,7 +468,7 @@ using just MongoDB, but most of what Django provides can still be used. One project which should make working with MongoDB and Django easier -is `mango `_. Mango is a set of +is `mango `_. Mango is a set of MongoDB backends for Django sessions and authentication (bypassing :mod:`django.db` entirely). 
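The FAQ's web-framework advice above boils down to one long-lived client per process, shared across requests; a minimal sketch (Flask assumed installed, server on the default port, database and collection names illustrative):

```python
from flask import Flask
from pymongo import MongoClient

app = Flask(__name__)
client = MongoClient()  # One client per process; MongoClient is thread-safe.


@app.route("/profiles/count")
def profiles_count():
    # Reuse the shared client on every request instead of reconnecting.
    return {"count": client.test_database.profiles.count_documents({})}
```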
diff --git a/doc/index.rst b/doc/index.rst index 079738314a..c7616ca795 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -9,7 +9,7 @@ PyMongo |release| Documentation Overview -------- **PyMongo** is a Python distribution containing tools for working with -`MongoDB `_, and is the recommended way to +`MongoDB `_, and is the recommended way to work with MongoDB from Python. This documentation attempts to explain everything you need to know to use **PyMongo**. @@ -81,7 +81,7 @@ Issues ------ All issues should be reported (and can be tracked / voted for / commented on) at the main `MongoDB JIRA bug tracker -`_, in the "Python Driver" +`_, in the "Python Driver" project. Feature Requests / Feedback @@ -94,7 +94,7 @@ Contributing **PyMongo** has a large :doc:`community ` and contributions are always encouraged. Contributions can be as simple as minor tweaks to this documentation. To contribute, fork the project on -`GitHub `_ and send a +`GitHub `_ and send a pull request. Changes diff --git a/doc/installation.rst b/doc/installation.rst index abda06db16..837cbf4d97 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -3,7 +3,7 @@ Installing / Upgrading .. highlight:: bash **PyMongo** is in the `Python Package Index -`_. +`_. .. warning:: **Do not install the "bson" package from pypi.** PyMongo comes with its own bson package; doing "pip install bson" @@ -12,7 +12,7 @@ Installing / Upgrading Installing with pip ------------------- -We recommend using `pip `_ +We recommend using `pip `_ to install pymongo on all platforms:: $ python3 -m pip install pymongo @@ -136,7 +136,7 @@ is a workaround:: # For some Python builds from python.org $ env ARCHFLAGS='-arch i386 -arch x86_64' python -m pip install pymongo -See `http://bugs.python.org/issue11623 `_ +See `https://bugs.python.org/issue11623 `_ for a more detailed explanation. **Lion (10.7) and newer** - PyMongo's C extensions can be built against diff --git a/doc/make.bat b/doc/make.bat index 2119f51099..aa1adb91a6 100644 --- a/doc/make.bat +++ b/doc/make.bat @@ -21,7 +21,7 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://sphinx-doc.org/ exit /b 1 ) diff --git a/doc/tools.rst b/doc/tools.rst index 6dd0df8a4d..a3f167d024 100644 --- a/doc/tools.rst +++ b/doc/tools.rst @@ -31,7 +31,7 @@ MongoEngine layer on top of PyMongo. It allows you to define schemas for documents and query collections using syntax inspired by the Django ORM. The code is available on `GitHub - `_; for more information, see + `_; for more information, see the `tutorial `_. MincePy @@ -47,17 +47,15 @@ Ming `Ming `_ is a library that allows you to enforce schemas on a MongoDB database in your Python application. It was developed by `SourceForge - `_ in the course of their migration to - MongoDB. See the `introductory blog post - `_ - for more details. + `_ in the course of their migration to + MongoDB. MotorEngine `MotorEngine `_ is a port of MongoEngine to Motor, for asynchronous access with Tornado. It implements the same modeling APIs to be data-portable, meaning that a model defined in MongoEngine can be read in MotorEngine. The source is - `available on GitHub `_. + `available on GitHub `_. uMongo `uMongo `_ is a Python MongoDB ODM. @@ -67,6 +65,14 @@ uMongo mongomock. 
The source `is available on GitHub `_ +Django MongoDB Backend + `Django MongoDB Backend `_ is a + database backend library specifically made for Django. The integration takes + advantage of MongoDB's unique document model capabilities, which align + naturally with Django's philosophy of simplified data modeling and + reduced development complexity. The source is available + `on GitHub `_. + No longer maintained """""""""""""""""""" @@ -81,12 +87,12 @@ PyMODM `_. MongoKit - The `MongoKit `_ framework + The `MongoKit `_ framework is an ORM-like layer on top of PyMongo. There is also a MongoKit - `google group `_. + `google group `_. Minimongo - `minimongo `_ is a lightweight, + `minimongo `_ is a lightweight, pythonic interface to MongoDB. It retains pymongo's query and update API, and provides a number of additional features, including a simple document-oriented interface, connection pooling, index management, and @@ -94,15 +100,15 @@ Minimongo `_. Manga - `Manga `_ aims to be a simpler ORM-like + `Manga `_ aims to be a simpler ORM-like layer on top of PyMongo. The syntax for defining schema is inspired by the Django ORM, but Pymongo's query language is maintained. The source `is on - GitHub `_. + GitHub `_. Humongolus `Humongolus `_ is a lightweight ORM framework for Python and MongoDB. The name comes from the combination of - MongoDB and `Homunculus `_ (the + MongoDB and `Homunculus `_ (the concept of a miniature though fully formed human body). Humongolus allows you to create models/schemas with robust validation. It attempts to be as pythonic as possible and exposes the pymongo cursor objects whenever @@ -125,30 +131,30 @@ various Python frameworks and libraries. database backend for Django that completely integrates with its ORM. For more information `see the tutorial `_. -* `mango `_ provides MongoDB backends for +* `mango `_ provides MongoDB backends for Django sessions and authentication (bypassing :mod:`django.db` entirely). * `Django MongoEngine `_ is a MongoDB backend for Django, an `example: `_. For more information see ``_ -* `mongodb_beaker `_ is a +* `mongodb_beaker `_ is a project to enable using MongoDB as a backend for `beakers `_ caching / session system. - `The source is on GitHub `_. + `The source is on GitHub `_. * `Log4Mongo `_ is a flexible Python logging handler that can store logs in MongoDB using normal and capped collections. -* `MongoLog `_ is a Python logging +* `MongoLog `_ is a Python logging handler that stores logs in MongoDB using a capped collection. -* `rod.recipe.mongodb `_ is a +* `rod.recipe.mongodb `_ is a ZC Buildout recipe for downloading and installing MongoDB. -* `mongobox `_ is a tool to run a sandboxed +* `mongobox `_ is a tool to run a sandboxed MongoDB instance from within a python app. -* `Flask-MongoAlchemy `_ Add +* `Flask-MongoAlchemy `_ Add Flask support for MongoDB using MongoAlchemy. -* `Flask-MongoKit `_ Flask extension +* `Flask-MongoKit `_ Flask extension to better integrate MongoKit into Flask. -* `Flask-PyMongo `_ Flask-PyMongo +* `Flask-PyMongo `_ Flask-PyMongo bridges Flask and PyMongo. Alternative Drivers diff --git a/doc/tutorial.rst b/doc/tutorial.rst index e33936363d..46bde3035d 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -375,7 +375,7 @@ Indexing Adding indexes can help accelerate certain queries and can also add additional functionality to querying and storing documents. 
In this example, we'll demonstrate how to create a `unique index -`_ on a key that rejects +`_ on a key that rejects documents whose value for that key already exists in the index. First, we'll need to create the index: diff --git a/green_framework_test.py b/green_framework_test.py deleted file mode 100644 index 037d0279c3..0000000000 --- a/green_framework_test.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test PyMongo with a variety of greenlet-based monkey-patching frameworks.""" -from __future__ import annotations - -import getopt -import sys - -import pytest - - -def run_gevent(): - """Prepare to run tests with Gevent. Can raise ImportError.""" - from gevent import monkey - - monkey.patch_all() - - -def run_eventlet(): - """Prepare to run tests with Eventlet. Can raise ImportError.""" - import eventlet - - # https://github.com/eventlet/eventlet/issues/401 - eventlet.sleep() - eventlet.monkey_patch() - - -FRAMEWORKS = { - "gevent": run_gevent, - "eventlet": run_eventlet, -} - - -def list_frameworks(): - """Tell the user what framework names are valid.""" - sys.stdout.write( - """Testable frameworks: %s - -Note that membership in this list means the framework can be tested with -PyMongo, not necessarily that it is officially supported. -""" - % ", ".join(sorted(FRAMEWORKS)) - ) - - -def run(framework_name, *args): - """Run tests with monkey-patching enabled. Can raise ImportError.""" - # Monkey-patch. - FRAMEWORKS[framework_name]() - - arg_list = list(args) - - # Never run async tests with a framework - if len(arg_list) <= 1: - arg_list.extend(["-m", "not default_async and default"]) - else: - for i in range(len(arg_list) - 1): - if "-m" in arg_list[i]: - arg_list[i + 1] = f"not default_async and {arg_list[i + 1]}" - - # Run the tests. - sys.exit(pytest.main(arg_list)) - - -def main(): - """Parse options and run tests.""" - usage = f"""python {sys.argv[0]} FRAMEWORK_NAME - -Test PyMongo with a variety of greenlet-based monkey-patching frameworks. See -python {sys.argv[0]} --help-frameworks.""" - - try: - opts, args = getopt.getopt(sys.argv[1:], "h", ["help", "help-frameworks"]) - except getopt.GetoptError as err: - print(str(err)) - print(usage) - sys.exit(2) - - for option_name, _ in opts: - if option_name in ("-h", "--help"): - print(usage) - sys.exit() - elif option_name == "--help-frameworks": - list_frameworks() - sys.exit() - else: - raise AssertionError("unhandled option") - - if not args: - print(usage) - sys.exit(1) - - if args[0] not in FRAMEWORKS: - print("%r is not a testable framework.\n" % args[0]) - list_frameworks() - sys.exit(1) - - run( - args[0], - *args[1:], # Framework name. - ) # Command line args to pytest, like what test to run. 
- - -if __name__ == "__main__": - main() diff --git a/gridfs/asynchronous/__init__.py b/gridfs/asynchronous/__init__.py new file mode 100644 index 0000000000..0826145b11 --- /dev/null +++ b/gridfs/asynchronous/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GridFS is a specification for storing large objects in Mongo. + +The :mod:`gridfs` package is an implementation of GridFS on top of +:mod:`pymongo`, exposing a file-like interface. + +.. seealso:: The MongoDB documentation on `gridfs `_. +""" +from __future__ import annotations + +from gridfs.asynchronous.grid_file import ( + AsyncGridFS, + AsyncGridFSBucket, + AsyncGridIn, + AsyncGridOut, + AsyncGridOutCursor, +) +from gridfs.errors import NoFile +from gridfs.grid_file_shared import DEFAULT_CHUNK_SIZE + +__all__ = [ + "AsyncGridFS", + "AsyncGridFSBucket", + "NoFile", + "DEFAULT_CHUNK_SIZE", + "AsyncGridIn", + "AsyncGridOut", + "AsyncGridOutCursor", +] diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index a49d51d304..3c7d4ef0e9 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -100,7 +100,7 @@ def __init__(self, database: AsyncDatabase, collection: str = "fs"): .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(database, AsyncDatabase): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(database)}") database = _clear_entity_type_registry(database) @@ -231,7 +231,7 @@ async def get_version( try: doc = await anext(cursor) return AsyncGridOut(self._collection, file_document=doc, session=session) - except StopIteration: + except StopAsyncIteration: raise NoFile("no version %d for filename %r" % (version, filename)) from None async def get_last_version( @@ -503,7 +503,7 @@ def __init__( .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, AsyncDatabase): - raise TypeError("database must be an instance of AsyncDatabase") + raise TypeError(f"database must be an instance of AsyncDatabase, not {type(db)}") db = _clear_entity_type_registry(db) @@ -834,6 +834,35 @@ async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = Non if not res.deleted_count: raise NoFile("no file could be deleted because none matched %s" % file_id) + @_csot.apply + async def delete_by_name( + self, filename: str, session: Optional[AsyncClientSession] = None + ) -> None: + """Given a filename, delete this stored file's files collection document(s) + and associated chunks from a GridFS bucket. + + For example:: + + my_db = AsyncMongoClient().test + fs = AsyncGridFSBucket(my_db) + await fs.upload_from_stream("test_file", "data I want to store!") + await fs.delete_by_name("test_file") + + Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists. + + :param filename: The name of the file to be deleted. 
+ :param session: a :class:`~pymongo.client_session.AsyncClientSession` + + .. versionadded:: 4.12 + """ + _disallow_transactions(session) + files = self._files.find({"filename": filename}, {"_id": 1}, session=session) + file_ids = [file["_id"] async for file in files] + res = await self._files.delete_many({"_id": {"$in": file_ids}}, session=session) + await self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session) + if not res.deleted_count: + raise NoFile(f"no file could be deleted because none matched filename {filename!r}") + def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor: """Find and return the files collection documents that match ``filter`` @@ -1021,6 +1050,35 @@ async def rename( "matched file_id %i" % (new_filename, file_id) ) + async def rename_by_name( + self, filename: str, new_filename: str, session: Optional[AsyncClientSession] = None + ) -> None: + """Renames the stored file with the specified filename. + + For example:: + + my_db = AsyncMongoClient().test + fs = AsyncGridFSBucket(my_db) + await fs.upload_from_stream("test_file", "data I want to store!") + await fs.rename_by_name("test_file", "new_test_name") + + Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists. + + :param filename: The filename of the file to be renamed. + :param new_filename: The new name of the file. + :param session: a :class:`~pymongo.client_session.AsyncClientSession` + + .. versionadded:: 4.12 + """ + _disallow_transactions(session) + result = await self._files.update_many( + {"filename": filename}, {"$set": {"filename": new_filename}}, session=session + ) + if not result.matched_count: + raise NoFile( + f"no files could be renamed {new_filename!r} because none matched filename {filename!r}" + ) + class AsyncGridIn: """Class to write data to GridFS.""" @@ -1082,7 +1140,9 @@ def __init__( :attr:`~pymongo.collection.AsyncCollection.write_concern` """ if not isinstance(root_collection, AsyncCollection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError( + f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}" + ) if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1299,11 +1359,8 @@ async def write(self, data: Any) -> None: raise ValueError("cannot write to a closed file") try: - if isinstance(data, AsyncGridOut): - read = data.read - else: - # file-like - read = data.read + # file-like + read = data.read except AttributeError: # string if not isinstance(data, (str, bytes)): @@ -1315,7 +1372,7 @@ async def write(self, data: Any) -> None: raise TypeError( "must specify an encoding for file in order to write str" ) from None - read = io.BytesIO(data).read # type: ignore[assignment] + read = io.BytesIO(data).read if inspect.iscoroutinefunction(read): await self._write_async(read) @@ -1329,15 +1386,15 @@ async def write(self, data: Any) -> None: except BaseException: await self.abort() raise - self._buffer.write(to_write) # type: ignore - if len(to_write) < space: # type: ignore + self._buffer.write(to_write) + if len(to_write) < space: return # EOF or incomplete await self._flush_buffer() to_write = read(self.chunk_size) - while to_write and len(to_write) == self.chunk_size: # type: ignore + while to_write and len(to_write) == self.chunk_size: await self._flush_data(to_write) to_write = read(self.chunk_size) - self._buffer.write(to_write) # type: ignore + self._buffer.write(to_write) async 
def _write_async(self, read: Any) -> None: if self._buffer.tell() > 0: @@ -1436,7 +1493,9 @@ def __init__( from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, AsyncCollection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError( + f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}" + ) _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 655f05f57a..d0a4c7fc7f 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -100,7 +100,7 @@ def __init__(self, database: Database, collection: str = "fs"): .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(database, Database): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(database)}") database = _clear_entity_type_registry(database) @@ -501,7 +501,7 @@ def __init__( .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, Database): - raise TypeError("database must be an instance of Database") + raise TypeError(f"database must be an instance of Database, not {type(db)}") db = _clear_entity_type_registry(db) @@ -830,6 +830,33 @@ def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: if not res.deleted_count: raise NoFile("no file could be deleted because none matched %s" % file_id) + @_csot.apply + def delete_by_name(self, filename: str, session: Optional[ClientSession] = None) -> None: + """Given a filename, delete this stored file's files collection document(s) + and associated chunks from a GridFS bucket. + + For example:: + + my_db = MongoClient().test + fs = GridFSBucket(my_db) + fs.upload_from_stream("test_file", "data I want to store!") + fs.delete_by_name("test_file") + + Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists. + + :param filename: The name of the file to be deleted. + :param session: a :class:`~pymongo.client_session.ClientSession` + + .. versionadded:: 4.12 + """ + _disallow_transactions(session) + files = self._files.find({"filename": filename}, {"_id": 1}, session=session) + file_ids = [file["_id"] for file in files] + res = self._files.delete_many({"_id": {"$in": file_ids}}, session=session) + self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session) + if not res.deleted_count: + raise NoFile(f"no file could be deleted because none matched filename {filename!r}") + def find(self, *args: Any, **kwargs: Any) -> GridOutCursor: """Find and return the files collection documents that match ``filter`` @@ -1015,6 +1042,35 @@ def rename( "matched file_id %i" % (new_filename, file_id) ) + def rename_by_name( + self, filename: str, new_filename: str, session: Optional[ClientSession] = None + ) -> None: + """Renames the stored file with the specified filename. + + For example:: + + my_db = MongoClient().test + fs = GridFSBucket(my_db) + fs.upload_from_stream("test_file", "data I want to store!") + fs.rename_by_name("test_file", "new_test_name") + + Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists. + + :param filename: The filename of the file to be renamed. + :param new_filename: The new name of the file. + :param session: a :class:`~pymongo.client_session.ClientSession` + + .. 
versionadded:: 4.12 + """ + _disallow_transactions(session) + result = self._files.update_many( + {"filename": filename}, {"$set": {"filename": new_filename}}, session=session + ) + if not result.matched_count: + raise NoFile( + f"no files could be renamed {new_filename!r} because none matched filename {filename!r}" + ) + class GridIn: """Class to write data to GridFS.""" @@ -1076,7 +1132,9 @@ def __init__( :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of Collection") + raise TypeError( + f"root_collection must be an instance of Collection, not {type(root_collection)}" + ) if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1289,11 +1347,8 @@ def write(self, data: Any) -> None: raise ValueError("cannot write to a closed file") try: - if isinstance(data, GridOut): - read = data.read - else: - # file-like - read = data.read + # file-like + read = data.read except AttributeError: # string if not isinstance(data, (str, bytes)): @@ -1305,7 +1360,7 @@ def write(self, data: Any) -> None: raise TypeError( "must specify an encoding for file in order to write str" ) from None - read = io.BytesIO(data).read # type: ignore[assignment] + read = io.BytesIO(data).read if inspect.iscoroutinefunction(read): self._write_async(read) @@ -1319,15 +1374,15 @@ def write(self, data: Any) -> None: except BaseException: self.abort() raise - self._buffer.write(to_write) # type: ignore - if len(to_write) < space: # type: ignore + self._buffer.write(to_write) + if len(to_write) < space: return # EOF or incomplete self._flush_buffer() to_write = read(self.chunk_size) - while to_write and len(to_write) == self.chunk_size: # type: ignore + while to_write and len(to_write) == self.chunk_size: self._flush_data(to_write) to_write = read(self.chunk_size) - self._buffer.write(to_write) # type: ignore + self._buffer.write(to_write) def _write_async(self, read: Any) -> None: if self._buffer.tell() > 0: @@ -1426,7 +1481,9 @@ def __init__( from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of Collection") + raise TypeError( + f"root_collection must be an instance of Collection, not {type(root_collection)}" + ) _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) diff --git a/justfile b/justfile index 8a076038a4..43aefb3f1a 100644 --- a/justfile +++ b/justfile @@ -1,7 +1,5 @@ # See https://just.systems/man/en/ for instructions set shell := ["bash", "-c"] -set dotenv-load -set dotenv-filename := "./.evergreen/scripts/env.sh" # Commonly used command segments. 
uv_run := "uv run --isolated --frozen " @@ -63,17 +61,21 @@ test *args="-v --durations=5 --maxfail=10": {{uv_run}} --extra test pytest {{args}} [group('test')] -test-mockupdb *args: - {{uv_run}} -v --extra test --group mockupdb pytest -m mockupdb {{args}} +run-tests *args: + bash ./.evergreen/run-tests.sh {{args}} [group('test')] -test-eg *args: - bash ./.evergreen/run-tests.sh {{args}} +setup-tests *args="": + bash .evergreen/scripts/setup-tests.sh {{args}} + +[group('test')] +teardown-tests: + bash .evergreen/scripts/teardown-tests.sh -[group('encryption')] -setup-encryption: - bash .evergreen/setup-encryption.sh +[group('server')] +run-server *args="": + bash .evergreen/scripts/run-server.sh {{args}} -[group('encryption')] -teardown-encryption: - bash .evergreen/teardown-encryption.sh +[group('server')] +stop-server: + bash .evergreen/scripts/stop-server.sh diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 58f6ff338b..95eabef242 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -55,7 +55,7 @@ GEO2D = "2d" """Index specifier for a 2-dimensional `geospatial index`_. -.. _geospatial index: http://mongodb.com/docs/manual/core/2d/ +.. _geospatial index: https://mongodb.com/docs/manual/core/2d/ """ GEOSPHERE = "2dsphere" @@ -63,7 +63,7 @@ .. versionadded:: 2.5 -.. _spherical geospatial index: http://mongodb.com/docs/manual/core/2dsphere/ +.. _spherical geospatial index: https://mongodb.com/docs/manual/core/2dsphere/ """ HASHED = "hashed" @@ -71,7 +71,7 @@ .. versionadded:: 2.5 -.. _hashed index: http://mongodb.com/docs/manual/core/index-hashed/ +.. _hashed index: https://mongodb.com/docs/manual/core/index-hashed/ """ TEXT = "text" @@ -83,7 +83,7 @@ .. versionadded:: 2.7.1 -.. _text index: http://mongodb.com/docs/manual/core/index-text/ +.. _text index: https://mongodb.com/docs/manual/core/index-text/ """ from pymongo import _csot @@ -105,6 +105,16 @@ from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +# Public module compatibility imports +# isort: off +from pymongo import uri_parser # noqa: F401 +from pymongo import change_stream # noqa: F401 +from pymongo import client_session # noqa: F401 +from pymongo import collection # noqa: F401 +from pymongo import command_cursor # noqa: F401 +from pymongo import database # noqa: F401 +# isort: on + version = __version__ """Current version of PyMongo.""" @@ -160,7 +170,7 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]: .. 
versionadded:: 4.2 """ if not isinstance(seconds, (int, float, type(None))): - raise TypeError("timeout must be None, an int, or a float") + raise TypeError(f"timeout must be None, an int, or a float, not {type(seconds)}") if seconds and seconds < 0: raise ValueError("timeout cannot be negative") if seconds is not None: diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py index 669b0f63a7..a9c409d486 100644 --- a/pymongo/_asyncio_lock.py +++ b/pymongo/_asyncio_lock.py @@ -160,7 +160,7 @@ def release(self) -> None: self._locked = False self._wake_up_first() else: - raise RuntimeError("Lock is not acquired.") + raise RuntimeError("Lock is not acquired") def _wake_up_first(self) -> None: """Ensure that the first waiter will wake up.""" diff --git a/pymongo/_asyncio_task.py b/pymongo/_asyncio_task.py index 8e457763d9..7a528f027d 100644 --- a/pymongo/_asyncio_task.py +++ b/pymongo/_asyncio_task.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/_azure_helpers.py b/pymongo/_azure_helpers.py index 704c561cd5..8a7af0b407 100644 --- a/pymongo/_azure_helpers.py +++ b/pymongo/_azure_helpers.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -46,7 +46,7 @@ def _get_azure_response( try: data = json.loads(body) except Exception: - raise ValueError("Azure IMDS response must be in JSON format.") from None + raise ValueError("Azure IMDS response must be in JSON format") from None for key in ["access_token", "expires_in"]: if not data.get(key): diff --git a/pymongo/_client_bulk_shared.py b/pymongo/_client_bulk_shared.py index 649f1c6aa0..5814025566 100644 --- a/pymongo/_client_bulk_shared.py +++ b/pymongo/_client_bulk_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index eb457b341c..a506863737 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/_gcp_helpers.py b/pymongo/_gcp_helpers.py index d90f3cc217..7979d1e807 100644 --- a/pymongo/_gcp_helpers.py +++ b/pymongo/_gcp_helpers.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/_version.py b/pymongo/_version.py index 22972c5ce4..797910697c 100644 --- a/pymongo/_version.py +++ b/pymongo/_version.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,7 +18,7 @@ import re from typing import List, Tuple, Union -__version__ = "4.11" +__version__ = "4.12.1" def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]: diff --git a/pymongo/asynchronous/aggregation.py b/pymongo/asynchronous/aggregation.py index 7684151897..daccd1bcb0 100644 --- a/pymongo/asynchronous/aggregation.py +++ b/pymongo/asynchronous/aggregation.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index b1e6d0125b..c1321f1d90 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -161,7 +161,7 @@ def _password_digest(username: str, password: str) -> str: if len(password) == 0: raise ValueError("password can't be empty") if not isinstance(username, str): - raise TypeError("username must be an instance of str") + raise TypeError(f"username must be an instance of str, not {type(username)}") md5hash = hashlib.md5() # noqa: S324 data = f"{username}:mongo:{password}" diff --git a/pymongo/asynchronous/auth_aws.py b/pymongo/asynchronous/auth_aws.py index 9dcc625d19..210d306046 100644 --- a/pymongo/asynchronous/auth_aws.py +++ b/pymongo/asynchronous/auth_aws.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py index f1c15045de..217c8104a2 100644 --- a/pymongo/asynchronous/auth_oidc.py +++ b/pymongo/asynchronous/auth_oidc.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -213,7 +213,9 @@ def _get_access_token(self) -> Optional[str]: ) resp = cb.fetch(context) if not isinstance(resp, OIDCCallbackResult): - raise ValueError("Callback result must be of type OIDCCallbackResult") + raise ValueError( + f"Callback result must be of type OIDCCallbackResult, not {type(resp)}" + ) self.refresh_token = resp.refresh_token self.access_token = resp.access_token self.token_gen_id += 1 diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 6770d7b34e..ac514db98f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -87,7 +87,7 @@ def __init__( self, collection: AsyncCollection[_DocumentType], ordered: bool, - bypass_document_validation: bool, + bypass_document_validation: Optional[bool], comment: Optional[str] = None, let: Optional[Any] = None, ) -> None: @@ -255,8 +255,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -276,8 +276,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -302,8 +302,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -340,8 +340,8 @@ async def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -366,8 +366,8 @@ async def unack_write( if 
_COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -393,8 +393,8 @@ async def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -516,8 +516,8 @@ async def _execute_command( if self.comment: cmd["comment"] = self.comment _csot.apply_write_concern(cmd, write_concern) - if self.bypass_doc_val: - cmd["bypassDocumentValidation"] = True + if self.bypass_doc_val is not None: + cmd["bypassDocumentValidation"] = self.bypass_doc_val if self.let is not None and run.op_type in (_DELETE, _UPDATE): cmd["let"] = self.let if session: diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 719020c409..6c37f9d05f 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -391,7 +391,8 @@ async def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 45824256da..5f7ac013e9 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -241,8 +241,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -262,8 +262,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -289,8 +289,8 @@ async def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -330,8 +330,8 @@ async def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -356,8 +356,8 @@ async def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -383,8 +383,8 @@ async def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index d80495d804..b808684dd4 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,7 +21,7 @@ .. 
code-block:: python

-    with client.start_session(causal_consistency=True) as session:
+    async with client.start_session(causal_consistency=True) as session:
         collection = client.db.collection
         await collection.update_one({"_id": 1}, {"$set": {"x": 10}}, session=session)
         secondary_c = collection.with_options(read_preference=ReadPreference.SECONDARY)
@@ -53,8 +53,8 @@
     orders = client.db.orders
     inventory = client.db.inventory
-    with client.start_session() as session:
-        async with session.start_transaction():
+    async with client.start_session() as session:
+        async with await session.start_transaction():
             await orders.insert_one({"sku": "abc123", "qty": 100}, session=session)
             await inventory.update_one(
                 {"sku": "abc123", "qty": {"$gte": 100}},
@@ -62,7 +62,7 @@
                 session=session,
             )
-Upon normal completion of ``async with session.start_transaction()`` block, the
+Upon normal completion of the ``async with await session.start_transaction()`` block, the
 transaction automatically calls :meth:`AsyncClientSession.commit_transaction`. If
 the block exits with an exception, the transaction automatically calls
 :meth:`AsyncClientSession.abort_transaction`.
@@ -113,7 +113,7 @@
 .. code-block:: python

       # Each read using this session reads data from the same point in time.
-      with client.start_session(snapshot=True) as session:
+      async with client.start_session(snapshot=True) as session:
           order = await orders.find_one({"sku": "abc123"}, session=session)
           inventory = await inventory.find_one({"sku": "abc123"}, session=session)
@@ -310,7 +310,9 @@ def __init__(
             )
         if max_commit_time_ms is not None:
             if not isinstance(max_commit_time_ms, int):
-                raise TypeError("max_commit_time_ms must be an integer or None")
+                raise TypeError(
+                    f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}"
+                )

     @property
     def read_concern(self) -> Optional[ReadConcern]:
@@ -456,10 +458,10 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:

 # From the transactions spec, all the retryable writes errors plus
-# WriteConcernFailed.
+# WriteConcernTimeout.
 _UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
     [
-        64,  # WriteConcernFailed
+        64,  # WriteConcernTimeout
         50,  # MaxTimeMSExpired
     ]
 )
@@ -617,7 +619,7 @@ async def callback(session):
                 await inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}},
                                            {"$inc": {"qty": -100}}, session=session)

-            with client.start_session() as session:
+            async with client.start_session() as session:
                 await session.with_transaction(callback)

         To pass arbitrary arguments to the ``callback``, wrap your callable
@@ -626,7 +628,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
                 # Transaction operations...

-            with client.start_session() as session:
+            async with client.start_session() as session:
                 await session.with_transaction(
                     lambda s: callback(s, "custom_arg", custom_kwarg=1))

@@ -695,7 +697,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
         )
         try:
             ret = await callback(self)
-        except Exception as exc:
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
+        except BaseException as exc:
             if self.in_transaction:
                 await self.abort_transaction()
             if (
@@ -902,7 +905,9 @@ def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None:
            another `AsyncClientSession` instance.
""" if not isinstance(cluster_time, _Mapping): - raise TypeError("cluster_time must be a subclass of collections.Mapping") + raise TypeError( + f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}" + ) if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) @@ -923,7 +928,9 @@ def advance_operation_time(self, operation_time: Timestamp) -> None: another `AsyncClientSession` instance. """ if not isinstance(operation_time, Timestamp): - raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") + raise TypeError( + f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}" + ) self._advance_operation_time(operation_time) def _process_response(self, reply: Mapping[str, Any]) -> None: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 9b73423627..7fb20b7ab3 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -228,7 +228,7 @@ def __init__( read_concern or database.read_concern, ) if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): @@ -701,7 +701,7 @@ async def bulk_write( self, requests: Sequence[_WriteOp[_DocumentType]], ordered: bool = True, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, let: Optional[Mapping] = None, @@ -800,7 +800,7 @@ async def _insert_one( ordered: bool, write_concern: WriteConcern, op_id: Optional[int], - bypass_doc_val: bool, + bypass_doc_val: Optional[bool], session: Optional[AsyncClientSession], comment: Optional[Any] = None, ) -> Any: @@ -814,8 +814,8 @@ async def _insert_one( async def _insert_command( session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> None: - if bypass_doc_val: - command["bypassDocumentValidation"] = True + if bypass_doc_val is not None: + command["bypassDocumentValidation"] = bypass_doc_val result = await conn.command( self._database.name, @@ -840,7 +840,7 @@ async def _insert_command( async def insert_one( self, document: Union[_DocumentType, RawBSONDocument], - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertOneResult: @@ -906,7 +906,7 @@ async def insert_many( self, documents: Iterable[Union[_DocumentType, RawBSONDocument]], ordered: bool = True, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertManyResult: @@ -986,7 +986,7 @@ async def _update( write_concern: Optional[WriteConcern] = None, op_id: Optional[int] = None, ordered: bool = True, - bypass_doc_val: 
Optional[bool] = False, + bypass_doc_val: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -1041,8 +1041,8 @@ async def _update( if comment is not None: command["comment"] = comment # Update command. - if bypass_doc_val: - command["bypassDocumentValidation"] = True + if bypass_doc_val is not None: + command["bypassDocumentValidation"] = bypass_doc_val # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. @@ -1082,7 +1082,7 @@ async def _update_retryable( write_concern: Optional[WriteConcern] = None, op_id: Optional[int] = None, ordered: bool = True, - bypass_doc_val: Optional[bool] = False, + bypass_doc_val: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -1128,7 +1128,7 @@ async def replace_one( filter: Mapping[str, Any], replacement: Mapping[str, Any], upsert: bool = False, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[AsyncClientSession] = None, @@ -1237,7 +1237,7 @@ async def update_one( filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -2475,7 +2475,7 @@ async def _drop_index( name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): - raise TypeError("index_or_name must be an instance of str or list") + raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}") cmd = {"dropIndexes": self._name, "index": name} cmd.update(kwargs) @@ -2948,6 +2948,7 @@ async def aggregate( returning aggregate results using a cursor. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. + - `bypassDocumentValidation` (bool): If ``True``, allows the write to opt-out of document level validation. :return: A :class:`~pymongo.asynchronous.command_cursor.AsyncCommandCursor` over the result @@ -3078,7 +3079,7 @@ async def rename( """ if not isinstance(new_name, str): - raise TypeError("new_name must be an instance of str") + raise TypeError(f"new_name must be an instance of str, not {type(new_name)}") if not new_name or ".." in new_name: raise InvalidName("collection names cannot be empty") @@ -3111,6 +3112,7 @@ async def distinct( filter: Optional[Mapping[str, Any]] = None, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, + hint: Optional[_IndexKeyHint] = None, **kwargs: Any, ) -> list: """Get a list of distinct values for `key` among all documents @@ -3138,8 +3140,15 @@ async def distinct( :class:`~pymongo.asynchronous.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` + (e.g. ``[('field', ASCENDING)]``). :param kwargs: See list of options above. + .. versionchanged:: 4.12 + Added ``hint`` parameter. + .. 
versionchanged:: 3.6 Added ``session`` parameter. @@ -3148,7 +3157,7 @@ async def distinct( """ if not isinstance(key, str): - raise TypeError("key must be an instance of str") + raise TypeError(f"key must be an instance of str, not {type(key)}") cmd = {"distinct": self._name, "key": key} if filter is not None: if "query" in kwargs: @@ -3158,6 +3167,10 @@ async def distinct( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment + if hint is not None: + if not isinstance(hint, str): + hint = helpers_shared._index_document(hint) + cmd["hint"] = hint # type: ignore[assignment] async def _cmd( session: Optional[AsyncClientSession], @@ -3196,7 +3209,7 @@ async def _find_and_modify( common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError( - "return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER" + f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}" ) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd = {"findAndModify": self._name, "query": filter, "new": return_document} diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 5a4559bd77..353c5e91c2 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -94,7 +94,9 @@ def __init__( self.batch_size(batch_size) if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) def __del__(self) -> None: self._die_no_lock() @@ -115,7 +117,7 @@ def batch_size(self, batch_size: int) -> AsyncCommandCursor[_DocumentType]: :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 8193e53282..1b25bf4ee8 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -146,9 +146,9 @@ def __init__( spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) if not isinstance(skip, int): - raise TypeError("skip must be an instance of int") + raise TypeError(f"skip must be an instance of int, not {type(skip)}") if not isinstance(limit, int): - raise TypeError("limit must be an instance of int") + raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) if no_cursor_timeout and not self._explicit_session: warnings.warn( @@ -171,7 +171,7 @@ def __init__( validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") # Only set if allow_disk_use is provided by the user, else None. 
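A minimal usage sketch of the ``hint`` parameter added to ``distinct`` above; the client, database, collection, and index names here are illustrative assumptions, not part of this change:

    from pymongo import ASCENDING, AsyncMongoClient

    async def distinct_skus() -> list:
        # Assumes a reachable server and an existing index such as [("sku", ASCENDING)].
        client: AsyncMongoClient = AsyncMongoClient()
        coll = client.db.orders
        # hint accepts an index name string or a create_index()-style key list;
        # non-string hints are converted with helpers_shared._index_document() above.
        return await coll.distinct("sku", hint=[("sku", ASCENDING)])

Per the versionchanged note above, passing ``hint`` to ``distinct`` requires PyMongo 4.12+.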
@@ -388,7 +388,7 @@ async def add_option(self, mask: int) -> AsyncCursor[_DocumentType]: cursor.add_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -408,7 +408,7 @@ def remove_option(self, mask: int) -> AsyncCursor[_DocumentType]: cursor.remove_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -432,7 +432,7 @@ def allow_disk_use(self, allow_disk_use: bool) -> AsyncCursor[_DocumentType]: .. versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): - raise TypeError("allow_disk_use must be a bool") + raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}") self._check_okay_to_chain() self._allow_disk_use = allow_disk_use @@ -451,7 +451,7 @@ def limit(self, limit: int) -> AsyncCursor[_DocumentType]: .. seealso:: The MongoDB documentation on `limit `_. """ if not isinstance(limit, int): - raise TypeError("limit must be an integer") + raise TypeError(f"limit must be an integer, not {type(limit)}") if self._exhaust: raise InvalidOperation("Can't use limit and exhaust together.") self._check_okay_to_chain() @@ -479,7 +479,7 @@ def batch_size(self, batch_size: int) -> AsyncCursor[_DocumentType]: :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") self._check_okay_to_chain() @@ -499,7 +499,7 @@ def skip(self, skip: int) -> AsyncCursor[_DocumentType]: :param skip: the number of results to skip """ if not isinstance(skip, int): - raise TypeError("skip must be an integer") + raise TypeError(f"skip must be an integer, not {type(skip)}") if skip < 0: raise ValueError("skip must be >= 0") self._check_okay_to_chain() @@ -520,7 +520,7 @@ def max_time_ms(self, max_time_ms: Optional[int]) -> AsyncCursor[_DocumentType]: :param max_time_ms: the time limit after which the operation is aborted """ if not isinstance(max_time_ms, int) and max_time_ms is not None: - raise TypeError("max_time_ms must be an integer or None") + raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}") self._check_okay_to_chain() self._max_time_ms = max_time_ms @@ -543,7 +543,9 @@ def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> AsyncCursor[_Do .. versionadded:: 3.2 """ if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) self._check_okay_to_chain() # Ignore max_await_time_ms if not tailable or await_data is False. @@ -679,7 +681,7 @@ def max(self, spec: _Sort) -> AsyncCursor[_DocumentType]: .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._max = dict(spec) @@ -701,7 +703,7 @@ def min(self, spec: _Sort) -> AsyncCursor[_DocumentType]: .. 
versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._min = dict(spec) @@ -1124,7 +1126,8 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 98a0a6ff3b..d0089eb4ee 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -122,7 +122,7 @@ def __init__( from pymongo.asynchronous.mongo_client import AsyncMongoClient if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") if not isinstance(client, AsyncMongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. @@ -1310,7 +1310,7 @@ async def drop_collection( name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str") + raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}") encrypted_fields = await self._get_encrypted_fields( {"encryptedFields": encrypted_fields}, name, @@ -1374,7 +1374,9 @@ async def validate_collection( name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str or AsyncCollection") + raise TypeError( + f"name_or_collection must be an instance of str or AsyncCollection, not {type(name)}" + ) cmd = {"validate": name, "scandata": scandata, "full": full} if comment is not None: cmd["comment"] = comment diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 98ab68527c..9b0757b1a5 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -64,11 +64,6 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.pool import ( - _configured_socket, - _get_timeout_details, - _raise_connection_failure, -) from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts @@ -80,14 +75,19 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall +from pymongo.network_layer import async_socket_sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _async_configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult -from pymongo.ssl_support import get_ssl_context +from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -113,7 +113,7 @@ async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: try: - return await _configured_socket(address, opts) + return await _async_configured_socket(address, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -127,8 +127,6 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptionError(exc) from exc @@ -159,6 +157,7 @@ def __init__( self.mongocryptd_client = mongocryptd_client self.opts = opts self._spawned = False + self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC) async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -170,7 +169,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: endpoint = kms_context.endpoint message = kms_context.message provider = kms_context.kms_provider - ctx = self.opts._kms_ssl_contexts.get(provider) + ctx = self._kms_ssl_contexts.get(provider) if ctx is None: # Enable strict certificate verification, OCSP, match hostname, and # SNI using the system default CA certificates. @@ -182,6 +181,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: False, # allow_invalid_certificates False, # allow_invalid_hostnames False, # disable_ocsp_endpoint_check + _IS_SYNC, ) # CSOT: set timeout for socket creation. 
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001) @@ -198,7 +198,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = await _connect_kms(address, opts) try: - await async_sendall(conn, message) + await async_socket_sendall(conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) @@ -244,7 +244,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) raise exc from final_err - async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: + async def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]: """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads @@ -253,14 +253,12 @@ async def collection_info(self, database: str, filter: bytes) -> Optional[bytes] :param database: The database on which to run listCollections. :param filter: The filter to pass to listCollections. - :return: The first document from the listCollections command response as BSON. + :return: All documents from the listCollections command response as BSON. """ async with await self.client_ref()[database].list_collections( filter=RawBSONDocument(filter) ) as cursor: - async for doc in cursor: - return _dict_to_bson(doc, False, _DATA_KEY_OPTS) - return None + return [_dict_to_bson(doc, False, _DATA_KEY_OPTS) async for doc in cursor] def spawn(self) -> None: """Spawn mongocryptd. @@ -322,7 +320,9 @@ async def insert_data_key(self, data_key: bytes) -> Binary: raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS) data_key_id = raw_doc.get("_id") if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: - raise TypeError("data_key _id must be Binary with a UUID subtype") + raise TypeError( + f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}" + ) assert self.key_vault_coll is not None await self.key_vault_coll.insert_one(raw_doc) @@ -398,6 +398,8 @@ def __init__(self, client: AsyncMongoClient[_DocumentTypeArg], opts: AutoEncrypt encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS) self._bypass_auto_encryption = opts._bypass_auto_encryption self._internal_client = None + # parsing kms_ssl_contexts here so that parsing errors will be raised before internal clients are created + opts._kms_ssl_contexts(_IS_SYNC) def _get_internal_client( encrypter: _Encrypter, mongo_client: AsyncMongoClient[_DocumentTypeArg] @@ -445,6 +447,7 @@ def _get_internal_client( bypass_encryption=opts._bypass_auto_encryption, encrypted_fields_map=encrypted_fields_map, bypass_query_analysis=opts._bypass_query_analysis, + key_expiration_ms=opts._key_expiration_ms, ), ) self._closed = False @@ -547,11 +550,10 @@ class QueryType(str, enum.Enum): def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions: - opts = MongoCryptOptions(**kwargs) - # Opt into range V2 encryption. - if hasattr(opts, "enable_range_v2"): - opts.enable_range_v2 = True - return opts + # For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms. 
+    if kwargs.get("key_expiration_ms") is None:
+        kwargs.pop("key_expiration_ms", None)
+    return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True)

 class AsyncClientEncryption(Generic[_DocumentType]):
@@ -564,6 +566,7 @@ def __init__(
         key_vault_client: AsyncMongoClient[_DocumentTypeArg],
         codec_options: CodecOptions[_DocumentTypeArg],
         kms_tls_options: Optional[Mapping[str, Any]] = None,
+        key_expiration_ms: Optional[int] = None,
     ) -> None:
         """Explicit client-side field level encryption.
@@ -630,7 +633,12 @@ def __init__(
             Or to supply a client certificate::

                 kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}

+        :param key_expiration_ms: The cache expiration time for data encryption keys.
+            Defaults to ``None``, which defers to libmongocrypt's default (currently 60000).
+            Set to 0 to disable key expiration.
+
+        .. versionchanged:: 4.12
+           Added the `key_expiration_ms` parameter.
         .. versionchanged:: 4.0
            Added the `kms_tls_options` parameter and the "kmip" KMS provider.
@@ -644,7 +652,9 @@
             )

         if not isinstance(codec_options, CodecOptions):
-            raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
+            raise TypeError(
+                f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
+            )

         if not isinstance(key_vault_client, AsyncMongoClient):
             # This is for compatibility with mocked and subclassed types, such as in Motor.
@@ -664,14 +674,20 @@
         key_vault_coll = key_vault_client[db][coll]

         opts = AutoEncryptionOpts(
-            kms_providers, key_vault_namespace, kms_tls_options=kms_tls_options
+            kms_providers,
+            key_vault_namespace,
+            kms_tls_options=kms_tls_options,
+            key_expiration_ms=key_expiration_ms,
         )
+        self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
         self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
             None, key_vault_coll, None, opts
         )
         self._encryption = AsyncExplicitEncrypter(
             self._io_callbacks,
-            _create_mongocrypt_options(kms_providers=kms_providers, schema_map=None),
+            _create_mongocrypt_options(
+                kms_providers=kms_providers, schema_map=None, key_expiration_ms=key_expiration_ms
+            ),
         )
         # Use the same key vault collection as the callback.
         assert self._io_callbacks.key_vault_coll is not None
@@ -698,6 +714,7 @@ async def create_encrypted_collection(
         creation.

         :class:`~pymongo.errors.EncryptionError` will be raised if the collection
         already exists.

+        :param database: the database in which to create the collection
         :param name: the name of the collection to create
         :param encrypted_fields: Document that describes the encrypted fields for
             Queryable Encryption. The "keyId" may be set to ``None`` to auto-generate
             the data keys. For example:
@@ -762,8 +779,6 @@ async def create_encrypted_collection(
                 await database.create_collection(name=name, **kwargs),
                 encrypted_fields,
             )
-        except asyncio.CancelledError:
-            raise
         except Exception as exc:
             raise EncryptedCollectionError(exc, encrypted_fields) from exc

diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py
index d519e8749c..88b710345b 100644
--- a/pymongo/asynchronous/helpers.py
+++ b/pymongo/asynchronous/helpers.py
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1600e50628..a236b21348 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -44,6 +44,7 @@ AsyncContextManager, AsyncGenerator, Callable, + Collection, Coroutine, FrozenSet, Generic, @@ -60,8 +61,8 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser -from pymongo.asynchronous import client_session, database +from pymongo import _csot, common, helpers_shared, periodic_executor +from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -88,9 +89,15 @@ _async_create_lock, _release_locks, ) -from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.logger import ( + _CLIENT_LOGGER, + _COMMAND_LOGGER, + _debug_log, + _log_client_error, + _log_or_warn, +) from pymongo.message import _CursorAddress, _GetMore, _Query -from pymongo.monitoring import ConnectionClosedReason +from pymongo.monitoring import ConnectionClosedReason, _EventListeners from pymongo.operations import ( DeleteMany, DeleteOne, @@ -102,6 +109,7 @@ ) from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ClientBulkWriteResult +from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -113,11 +121,14 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -128,6 +139,7 @@ from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager + from pymongo.asynchronous.encryption import _Encrypter from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern @@ -192,7 +204,7 @@ def __init__( execute. The `host` parameter can be a full `mongodb URI - `_, in addition to + `_, in addition to a simple hostname. It can also be a list of hostnames but no more than one URI. 
Any port specified in the host string(s) will override the `port` parameter. For username and @@ -276,7 +288,9 @@ def __init__( :param type_registry: instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. - :param datetime_conversion: Specifies how UTC datetimes should be decoded + :param kwargs: **Additional optional parameters available as keyword arguments:** + + - `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded within BSON. Valid options include 'datetime_ms' to return as a DatetimeMS, 'datetime' to return as a datetime.datetime and raising a ValueError for out-of-range values, 'datetime_auto' to @@ -284,9 +298,6 @@ def __init__( out-of-range and 'datetime_clamp' to clamp to the minimum and maximum possible datetimes. Defaults to 'datetime'. See :ref:`handling-out-of-range-datetimes` for details. - - | **Other optional parameters can be passed as keyword arguments:** - - `directConnection` (optional): if ``True``, forces this client to connect directly to the specified MongoDB host as a standalone. If ``false``, the client connects to the entire replica set of @@ -750,7 +761,13 @@ def __init__( if port is None: port = self.PORT if not isinstance(port, int): - raise TypeError("port must be an instance of int") + raise TypeError(f"port must be an instance of int, not {type(port)}") + self._host = host + self._port = port + self._topology: Topology = None # type: ignore[assignment] + self._timeout: float | None = None + self._topology_settings: TopologySettings = None # type: ignore[assignment] + self._event_listeners: _EventListeners | None = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -761,8 +778,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} - seeds = set() + self._seeds = set() + is_srv = False username = None password = None dbase = None @@ -770,41 +789,34 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. 
- timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) - res = uri_parser.parse_uri( + res = _validate_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) - seeds.update(res["nodelist"]) + is_srv = entity.startswith(SRV_SCHEME) + self._seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) - if not seeds: + self._seeds.update(split_hosts(entity, self._port)) + if not self._seeds: raise ConfigurationError("need to specify at least one host") - for hostname in [node[0] for node in seeds]: + for hostname in [node[0] for node in self._seeds]: if _detect_external_db(hostname): break @@ -821,80 +833,180 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. 
username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _async_create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, + self._encrypter: Optional[_Encrypter] = None + + self._resolve_srv_info.update( + { + "is_srv": is_srv, + "username": username, + "password": password, + "dbase": dbase, + "seeds": self._seeds, + "fqdn": fqdn, + "srv_service_name": srv_service_name, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, ) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False - self._init_background() + self._loop: Optional[asyncio.AbstractEventLoop] = None + if not is_srv: + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + async def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = await uri_parser._parse_srv( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. 
+ tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + opts = self._normalize_and_validate_options(opts, seeds) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) + self._options = ClientOptions( + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC + ) + + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: + self._event_listeners = self._options.pool_options._event_listeners + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + topology_id=self._topology_settings._topology_id if self._topology_settings else None, + ) if self._options.auto_encryption_opts: from pymongo.asynchronous.encryption import _Encrypter self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - AsyncMongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. 
+ opts.update(keyword_opts) + return opts async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" @@ -902,6 +1014,10 @@ async def aconnect(self) -> None: def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + AsyncMongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1090,6 +1206,16 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + if self._topology is None: + servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds} + return TopologyDescription( + TOPOLOGY_TYPE.Unknown, + servers, + None, + None, + None, + self._topology_settings, + ) return self._topology.description @property @@ -1103,6 +1229,8 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + if self._topology is None: + return frozenset() description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1116,16 +1244,24 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1141,13 +1277,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if self._topology is None: + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] + else: + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1450,6 +1589,8 @@ async def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + if self._topology is None: + await self._get_topology() topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1472,6 +1613,8 @@ async def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ + if self._topology is None: + await self._get_topology() return await self._topology.get_primary() # type: ignore[return-value] @property @@ -1485,6 +1628,8 @@ async def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. 
""" + if self._topology is None: + await self._get_topology() return await self._topology.get_secondaries() @property @@ -1495,6 +1640,8 @@ async def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + if self._topology is None: + await self._get_topology() return await self._topology.get_arbiters() @property @@ -1553,6 +1700,8 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ + if self._topology is None: + return session_ids = self._topology.pop_all_sessions() if session_ids: await self._end_sessions(session_ids) @@ -1565,6 +1714,12 @@ async def close(self) -> None: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() self._closed = True + if not _IS_SYNC: + await asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.aclosing. @@ -1576,7 +1731,17 @@ async def _get_topology(self) -> Topology: If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use AsyncMongoClient in different event loop. AsyncMongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: + if self._resolve_srv_info["is_srv"]: + await self._resolve_srv() + self._init_background() await self._topology.open() async with self._lock: self._kill_cursors_executor.open() @@ -1951,7 +2116,7 @@ async def _cleanup_cursor_lock( # exhausted the result set we *must* close the socket # to stop the server from sending more data. assert conn_mgr.conn is not None - conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) + await conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) else: await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) if conn_mgr: @@ -1971,7 +2136,7 @@ async def _close_cursor_now( The cursor is closed synchronously on the current thread. """ if not isinstance(cursor_id, int): - raise TypeError("cursor_id must be an instance of int") + raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}") try: if conn_mgr: @@ -2038,15 +2203,13 @@ async def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it # can be caught in _process_periodic_tasks raise else: - helpers_shared._handle_exception() + _log_client_error() # Don't re-open topology if it's closed and there's no pending cursors. 
        # Don't re-open topology if it's closed and there's no pending cursors.
        if address_to_cursor_ids:
@@ -2054,13 +2217,11 @@ async def _process_kill_cursors(self) -> None:
             for address, cursor_ids in address_to_cursor_ids.items():
                 try:
                     await self._kill_cursors(cursor_ids, address, topology, session=None)
-                except asyncio.CancelledError:
-                    raise
                 except Exception as exc:
                     if isinstance(exc, InvalidOperation) and self._topology._closed:
                         raise
                     else:
-                        helpers_shared._handle_exception()
+                        _log_client_error()

     # This method is run periodically by a background thread.
     async def _process_periodic_tasks(self) -> None:
@@ -2070,13 +2231,11 @@ async def _process_periodic_tasks(self) -> None:
         try:
             await self._process_kill_cursors()
             await self._topology.update_pool()
-        except asyncio.CancelledError:
-            raise
         except Exception as exc:
             if isinstance(exc, InvalidOperation) and self._topology._closed:
                 return
             else:
-                helpers_shared._handle_exception()
+                _log_client_error()

     def _return_server_session(
         self, server_session: Union[_ServerSession, _EmptyServerSession]
@@ -2093,7 +2252,9 @@ async def _tmp_session(
         """If provided session is None, lend a temporary session."""
         if session is not None:
             if not isinstance(session, client_session.AsyncClientSession):
-                raise ValueError("'session' argument must be an AsyncClientSession or None.")
+                raise ValueError(
+                    f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
+                )
             # Don't call end_session.
             yield session
             return
@@ -2247,7 +2408,9 @@ async def drop_database(
             name = name.name

         if not isinstance(name, str):
-            raise TypeError("name_or_database must be an instance of str or a AsyncDatabase")
+            raise TypeError(
+                f"name_or_database must be an instance of str or an AsyncDatabase, not {type(name)}"
+            )

         async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
             await self[name]._command(
@@ -2508,6 +2671,7 @@ async def handle(
                 self.completed_handshake,
                 self.service_id,
             )
+            assert self.client._topology is not None
             await self.client._topology.handle_error(self.server_address, err_ctx)

     async def __aenter__(self) -> _MongoClientErrorHandler:
@@ -2557,6 +2721,7 @@ def __init__(
         self._deprioritized_servers: list[Server] = []
         self._operation = operation
         self._operation_id = operation_id
+        self._attempt_number = 0

     async def run(self) -> T:
         """Runs the supplied func() and attempts a retry
@@ -2599,6 +2764,7 @@ async def run(self) -> T:
                     raise
                 self._retrying = True
                 self._last_error = exc
+                self._attempt_number += 1
             else:
                 raise
@@ -2620,6 +2786,7 @@ async def run(self) -> T:
                         raise self._last_error from exc
                     else:
                         raise
+                self._attempt_number += 1
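The new _attempt_number counter feeds the retry log messages added further down in _write() and _read(). Reduced to its skeleton, the retry loop looks roughly like this (a sketch; the real retryability checks are far stricter):

async def run_once_with_retry(func, retryable: bool):
    attempt = 0
    while True:
        try:
            return await func()
        except Exception:
            if not retryable or attempt >= 1:  # at most one retry
                raise
            attempt += 1  # surfaces as "Retrying ... attempt number 1"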
                # not support sessions raise the last error.
                self._check_last_error()
                self._retryable = False
+            if self._retrying:
+                _debug_log(
+                    _COMMAND_LOGGER,
+                    message=f"Retrying write attempt number {self._attempt_number}",
+                    clientId=self._client._topology_settings._topology_id,
+                    commandName=self._operation,
+                    operationId=self._operation_id,
+                )
             return await self._func(self._session, conn, self._retryable)  # type: ignore
         except PyMongoError as exc:
             if not self._retryable:
                 raise
@@ -2719,6 +2894,14 @@ async def _read(self) -> T:
         ):
             if self._retrying and not self._retryable:
                 self._check_last_error()
+            if self._retrying:
+                _debug_log(
+                    _COMMAND_LOGGER,
+                    message=f"Retrying read attempt number {self._attempt_number}",
+                    clientId=self._client._topology_settings._topology_id,
+                    commandName=self._operation,
+                    operationId=self._operation_id,
+                )
             return await self._func(self._session, self._server, conn, read_pref)  # type: ignore
diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py
index ad1bc70aba..32b545380a 100644
--- a/pymongo/asynchronous/monitor.py
+++ b/pymongo/asynchronous/monitor.py
@@ -4,7 +4,7 @@
 # may not use this file except in compliance with the License. You
 # may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -21,11 +21,12 @@
 import logging
 import time
 import weakref
-from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional

 from pymongo import common, periodic_executor
 from pymongo._csot import MovingMinimum
-from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
+from pymongo.asynchronous.srv_resolver import _SrvResolver
+from pymongo.errors import NetworkTimeout, _OperationCancelled
 from pymongo.hello import Hello
 from pymongo.lock import _async_create_lock
 from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
@@ -33,10 +34,13 @@
 from pymongo.pool_options import _is_faas
 from pymongo.read_preferences import MovingAverage
 from pymongo.server_description import ServerDescription
-from pymongo.srv_resolver import _SrvResolver

 if TYPE_CHECKING:
-    from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
+    from pymongo.asynchronous.pool import (  # type: ignore[attr-defined]
+        AsyncConnection,
+        Pool,
+        _CancellationContext,
+    )
     from pymongo.asynchronous.settings import TopologySettings
     from pymongo.asynchronous.topology import Topology

@@ -112,9 +116,9 @@ async def close(self) -> None:
         """
         self.gc_safe_close()

-    async def join(self, timeout: Optional[int] = None) -> None:
+    async def join(self) -> None:
         """Wait for the monitor to stop."""
-        await self._executor.join(timeout)
+        await self._executor.join()

     def request_check(self) -> None:
         """If the monitor is sleeping, wake it soon."""
@@ -189,6 +193,11 @@ def gc_safe_close(self) -> None:
         self._rtt_monitor.gc_safe_close()
         self.cancel_check()

+    async def join(self) -> None:
+        await asyncio.gather(
+            self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
+        )  # type: ignore[func-returns-value]
+
     async def close(self) -> None:
         self.gc_safe_close()
         await self._rtt_monitor.close()
@@ -250,15 +259,7 @@ async def _check_server(self) -> ServerDescription:
         self._conn_id = None
         start = time.monotonic()
         try:
-            try:
-                return await self._check_once()
-            except (OperationFailure, NotPrimaryError) as exc:
-                # Update max cluster time even when hello fails.
-                details = cast(Mapping[str, Any], exc.details)
-                await self._topology.receive_cluster_time(details.get("$clusterTime"))
-                raise
-            except asyncio.CancelledError:
-                raise
+            return await self._check_once()
         except ReferenceError:
             raise
         except Exception as error:
@@ -273,6 +274,7 @@ async def _check_server(self) -> ServerDescription:
             if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _SDAM_LOGGER,
+                    message=_SDAMStatusMessage.HEARTBEAT_FAIL,
                     topologyId=self._topology._topology_id,
                     serverHost=address[0],
                     serverPort=address[1],
@@ -280,7 +282,6 @@
                     durationMS=duration * 1000,
                     failure=error,
                     driverConnectionId=self._conn_id,
-                    message=_SDAMStatusMessage.HEARTBEAT_FAIL,
                 )
             await self._reset_connection()
             if isinstance(error, _OperationCancelled):
@@ -312,13 +313,13 @@ async def _check_once(self) -> ServerDescription:
         if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _SDAM_LOGGER,
+                message=_SDAMStatusMessage.HEARTBEAT_START,
                 topologyId=self._topology._topology_id,
                 driverConnectionId=conn.id,
                 serverConnectionId=conn.server_connection_id,
                 serverHost=address[0],
                 serverPort=address[1],
                 awaited=awaited,
-                message=_SDAMStatusMessage.HEARTBEAT_START,
             )
         self._cancel_context = conn.cancel_context
@@ -338,6 +339,7 @@
         if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _SDAM_LOGGER,
+                message=_SDAMStatusMessage.HEARTBEAT_SUCCESS,
                 topologyId=self._topology._topology_id,
                 driverConnectionId=conn.id,
                 serverConnectionId=conn.server_connection_id,
@@ -346,7 +348,6 @@
                 awaited=awaited,
                 durationMS=round_trip_time * 1000,
                 reply=response.document,
-                message=_SDAMStatusMessage.HEARTBEAT_SUCCESS,
             )
         return sd
@@ -355,7 +356,6 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]

         Can raise ConnectionFailure or OperationFailure.
         """
-        cluster_time = self._topology.max_cluster_time()
         start = time.monotonic()
         if conn.more_to_come:
             # Read the next streaming hello (MongoDB 4.4+).
@@ -365,13 +365,12 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]
         ):
             # Initiate streaming hello (MongoDB 4.4+).
             response = await conn._hello(
-                cluster_time,
                 self._server_description.topology_version,
                 self._settings.heartbeat_frequency,
             )
         else:
             # New connection handshake or polling hello (MongoDB <4.4).
-            response = await conn._hello(cluster_time, None, None)
+            response = await conn._hello(None, None)
         duration = _monotonic_duration(start)
         return response, duration
@@ -400,7 +399,7 @@ async def _run(self) -> None:
             # Don't poll right after creation, wait 60 seconds first
             if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
                 return
-            seedlist = self._get_seedlist()
+            seedlist = await self._get_seedlist()
             if seedlist:
                 self._seedlist = seedlist
                 try:
@@ -409,7 +408,7 @@ async def _run(self) -> None:
             # Topology was garbage-collected.
             await self.close()

-    def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
+    async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
         """Poll SRV records for a seedlist.

         Returns a list of ServerDescriptions.
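With _get_seedlist now a coroutine, the underlying DNS query awaits dnspython's async resolver instead of blocking the monitor task. Assuming dnspython >= 2.0, the lookup it performs is essentially this (a sketch; the real resolver also validates hosts and TTLs):

import dns.asyncresolver

async def srv_seedlist(fqdn: str, service: str = "mongodb") -> list[tuple[str, int]]:
    answer = await dns.asyncresolver.resolve(f"_{service}._tcp.{fqdn}", "SRV")
    return [(rr.target.to_text(omit_final_dot=True), rr.port) for rr in answer]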
@@ -420,12 +419,10 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
                 self._settings.pool_options.connect_timeout,
                 self._settings.srv_service_name,
             )
-            seedlist, ttl = resolver.get_hosts_and_min_ttl()
+            seedlist, ttl = await resolver.get_hosts_and_min_ttl()
             if len(seedlist) == 0:
                 # As per the spec: this should be treated as a failure.
                 raise Exception
-        except asyncio.CancelledError:
-            raise
         except Exception:
             # As per the spec, upon encountering an error:
             # - An error must not be raised
@@ -489,8 +486,6 @@ async def _run(self) -> None:
         except ReferenceError:
             # Topology was garbage-collected.
             await self.close()
-        except asyncio.CancelledError:
-            raise
         except Exception:
             await self._pool.reset()
diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py
index d17aead120..1605efe92d 100644
--- a/pymongo/asynchronous/network.py
+++ b/pymongo/asynchronous/network.py
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,7 +17,6 @@

 import datetime
 import logging
-import time
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -31,20 +30,16 @@

 from bson import _decode_all_selective
 from pymongo import _csot, helpers_shared, message
-from pymongo.common import MAX_MESSAGE_SIZE
-from pymongo.compression_support import _NO_COMPRESSION, decompress
+from pymongo.compression_support import _NO_COMPRESSION
 from pymongo.errors import (
     NotPrimaryError,
     OperationFailure,
-    ProtocolError,
 )
 from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
-from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
+from pymongo.message import _OpMsg
 from pymongo.monitoring import _is_speculative_authenticate
 from pymongo.network_layer import (
-    _UNPACK_COMPRESSION_HEADER,
-    _UNPACK_HEADER,
-    async_receive_data,
+    async_receive_message,
     async_sendall,
 )
@@ -168,8 +163,8 @@ async def command(
     if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
         _debug_log(
             _COMMAND_LOGGER,
-            clientId=client._topology_settings._topology_id,
             message=_CommandStatusMessage.STARTED,
+            clientId=client._topology_settings._topology_id,
             command=spec,
             commandName=next(iter(spec)),
             databaseName=dbname,
@@ -194,19 +189,23 @@
     )

     try:
-        await async_sendall(conn.conn, msg)
+        await async_sendall(conn.conn.get_conn, msg)
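The hand-rolled receive_message() removed at the end of this file parsed the standard wire-protocol framing that async_receive_message now owns: every MongoDB wire message starts with a 16-byte header of four little-endian int32s. For reference:

import struct

_UNPACK_HEADER = struct.Struct("<iiii").unpack

def parse_header(header: bytes) -> tuple[int, int, int, int]:
    # (messageLength, requestID, responseTo, opCode)
    return _UNPACK_HEADER(header)

assert parse_header(b"\x10\x00\x00\x00" + b"\x00" * 12)[0] == 16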
         if use_op_msg and unacknowledged:
             # Unacknowledged, fake a successful command response.
             reply = None
             response_doc: _DocumentOut = {"ok": 1}
         else:
-            reply = await receive_message(conn, request_id)
+            reply = await async_receive_message(conn, request_id)
             conn.more_to_come = reply.more_to_come
             unpacked_docs = reply.unpack_response(
                 codec_options=codec_options, user_fields=user_fields
             )

             response_doc = unpacked_docs[0]
+            if not conn.ready:
+                cluster_time = response_doc.get("$clusterTime")
+                if cluster_time:
+                    conn._cluster_time = cluster_time
             if client:
                 await client._process_response(response_doc, session)
             if check:
@@ -226,8 +225,8 @@
         if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _COMMAND_LOGGER,
-                clientId=client._topology_settings._topology_id,
                 message=_CommandStatusMessage.FAILED,
+                clientId=client._topology_settings._topology_id,
                 durationMS=duration,
                 failure=failure,
                 commandName=next(iter(spec)),
@@ -260,8 +259,8 @@
     if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
         _debug_log(
             _COMMAND_LOGGER,
-            clientId=client._topology_settings._topology_id,
             message=_CommandStatusMessage.SUCCEEDED,
+            clientId=client._topology_settings._topology_id,
             durationMS=duration,
             reply=response_doc,
             commandName=next(iter(spec)),
@@ -297,47 +296,3 @@ async def command(
     )

     return response_doc  # type: ignore[return-value]
-
-
-async def receive_message(
-    conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
-) -> Union[_OpReply, _OpMsg]:
-    """Receive a raw BSON message or raise socket.error."""
-    if _csot.get_timeout():
-        deadline = _csot.get_deadline()
-    else:
-        timeout = conn.conn.gettimeout()
-        if timeout:
-            deadline = time.monotonic() + timeout
-        else:
-            deadline = None
-    # Ignore the response's request id.
-    length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
-    # No request_id for exhaust cursor "getMore".
-    if request_id is not None:
-        if request_id != response_to:
-            raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
-    if length <= 16:
-        raise ProtocolError(
-            f"Message length ({length!r}) not longer than standard message header size (16)"
-        )
-    if length > max_message_size:
-        raise ProtocolError(
-            f"Message length ({length!r}) is larger than server max "
-            f"message size ({max_message_size!r})"
-        )
-    if op_code == 2012:
-        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
-            await async_receive_data(conn, 9, deadline)
-        )
-        data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
-    else:
-        data = await async_receive_data(conn, length - 16, deadline)
-
-    try:
-        unpack_reply = _UNPACK_REPLY[op_code]
-    except KeyError:
-        raise ProtocolError(
-            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
-        ) from None
-    return unpack_reply(data)
diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py
index bf2f2b4946..f4d5b174fa 100644
--- a/pymongo/asynchronous/pool.py
+++ b/pymongo/asynchronous/pool.py
@@ -4,7 +4,7 @@
 # may not use this file except in compliance with the License. You
 # may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -17,11 +17,8 @@
 import asyncio
 import collections
 import contextlib
-import functools
 import logging
 import os
-import socket
-import ssl
 import sys
 import time
 import weakref
@@ -40,8 +37,8 @@
 from bson import DEFAULT_CODEC_OPTIONS
 from pymongo import _csot, helpers_shared
 from pymongo.asynchronous.client_session import _validate_session_write_concern
-from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
-from pymongo.asynchronous.network import command, receive_message
+from pymongo.asynchronous.helpers import _handle_reauth
+from pymongo.asynchronous.network import command
 from pymongo.common import (
     MAX_BSON_SIZE,
     MAX_MESSAGE_SIZE,
@@ -52,16 +49,13 @@
 from pymongo.errors import (  # type:ignore[attr-defined]
     AutoReconnect,
     ConfigurationError,
-    ConnectionFailure,
     DocumentTooLarge,
     ExecutionTimeout,
     InvalidOperation,
-    NetworkTimeout,
     NotPrimaryError,
     OperationFailure,
     PyMongoError,
     WaitQueueTimeoutError,
-    _CertificateError,
 )
 from pymongo.hello import Hello, HelloCompat
 from pymongo.lock import (
@@ -79,13 +73,20 @@
     ConnectionCheckOutFailedReason,
     ConnectionClosedReason,
 )
-from pymongo.network_layer import async_sendall
+from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
 from pymongo.pool_options import PoolOptions
+from pymongo.pool_shared import (
+    SSLErrors,
+    _CancellationContext,
+    _configured_protocol_interface,
+    _get_timeout_details,
+    _raise_connection_failure,
+    format_timeout_details,
+)
 from pymongo.read_preferences import ReadPreference
 from pymongo.server_api import _add_to_command
 from pymongo.server_type import SERVER_TYPE
 from pymongo.socket_checker import SocketChecker
-from pymongo.ssl_support import HAS_SNI, SSLError

 if TYPE_CHECKING:
     from bson import CodecOptions
@@ -99,10 +100,9 @@
         ZstdContext,
     )
     from pymongo.message import _OpMsg, _OpReply
-    from pymongo.pyopenssl_context import _sslConn
    from pymongo.read_concern import ReadConcern
     from pymongo.read_preferences import _ServerMode
-    from pymongo.typings import ClusterTime, _Address, _CollationIn
+    from pymongo.typings import _Address, _CollationIn
     from pymongo.write_concern import WriteConcern

 try:
@@ -123,133 +123,6 @@ def _set_non_inheritable_non_atomic(fd: int) -> None:  # noqa: ARG001

 _IS_SYNC = False

-_MAX_TCP_KEEPIDLE = 120
-_MAX_TCP_KEEPINTVL = 10
-_MAX_TCP_KEEPCNT = 9
-
-if sys.platform == "win32":
-    try:
-        import _winreg as winreg
-    except ImportError:
-        import winreg
-
-    def _query(key, name, default):
-        try:
-            value, _ = winreg.QueryValueEx(key, name)
-            # Ensure the value is a number or raise ValueError.
-            return int(value)
-        except (OSError, ValueError):
-            # QueryValueEx raises OSError when the key does not exist (i.e.
-            # the system is using the Windows default value).
-            return default
-
-    try:
-        with winreg.OpenKey(
-            winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
-        ) as key:
-            _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
-            _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
-    except OSError:
-        # We could not check the default values because winreg.OpenKey failed.
-        # Assume the system is using the default values.
-        _WINDOWS_TCP_IDLE_MS = 7200000
-        _WINDOWS_TCP_INTERVAL_MS = 1000
-
-    def _set_keepalive_times(sock):
-        idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
-        interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
-        if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
-            sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
-
-else:
-
-    def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
-        if hasattr(socket, tcp_option):
-            sockopt = getattr(socket, tcp_option)
-            try:
-                # PYTHON-1350 - NetBSD doesn't implement getsockopt for
-                # TCP_KEEPIDLE and friends. Don't attempt to set the
-                # values there.
-                default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
-                if default > max_value:
-                    sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
-            except OSError:
-                pass
-
-    def _set_keepalive_times(sock: socket.socket) -> None:
-        _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
-        _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
-        _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
-
-
-def _raise_connection_failure(
-    address: Any,
-    error: Exception,
-    msg_prefix: Optional[str] = None,
-    timeout_details: Optional[dict[str, float]] = None,
-) -> NoReturn:
-    """Convert a socket.error to ConnectionFailure and raise it."""
-    host, port = address
-    # If connecting to a Unix socket, port will be None.
-    if port is not None:
-        msg = "%s:%d: %s" % (host, port, error)
-    else:
-        msg = f"{host}: {error}"
-    if msg_prefix:
-        msg = msg_prefix + msg
-    if "configured timeouts" not in msg:
-        msg += format_timeout_details(timeout_details)
-    if isinstance(error, socket.timeout):
-        raise NetworkTimeout(msg) from error
-    elif isinstance(error, SSLError) and "timed out" in str(error):
-        # Eventlet does not distinguish TLS network timeouts from other
-        # SSLErrors (https://github.com/eventlet/eventlet/issues/692).
-        # Luckily, we can work around this limitation because the phrase
-        # 'timed out' appears in all the timeout related SSLErrors raised.
-        raise NetworkTimeout(msg) from error
-    else:
-        raise AutoReconnect(msg) from error
-
-
-def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
-    details = {}
-    timeout = _csot.get_timeout()
-    socket_timeout = options.socket_timeout
-    connect_timeout = options.connect_timeout
-    if timeout:
-        details["timeoutMS"] = timeout * 1000
-    if socket_timeout and not timeout:
-        details["socketTimeoutMS"] = socket_timeout * 1000
-    if connect_timeout:
-        details["connectTimeoutMS"] = connect_timeout * 1000
-    return details
-
-
-def format_timeout_details(details: Optional[dict[str, float]]) -> str:
-    result = ""
-    if details:
-        result += " (configured timeouts:"
-        for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
-            if timeout in details:
-                result += f" {timeout}: {details[timeout]}ms,"
-        result = result[:-1]
-        result += ")"
-    return result
-
-
-class _CancellationContext:
-    def __init__(self) -> None:
-        self._cancelled = False
-
-    def cancel(self) -> None:
-        """Cancel this context."""
-        self._cancelled = True
-
-    @property
-    def cancelled(self) -> bool:
-        """Was cancel called?"""
-        return self._cancelled
-
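_get_timeout_details and format_timeout_details now live in pymongo.pool_shared; their behavior is unchanged. The annotation they append to connection errors looks like this (a behavior-equivalent sketch of the removed helper):

def format_timeout_details(details: dict[str, float]) -> str:
    keys = ("socketTimeoutMS", "timeoutMS", "connectTimeoutMS")
    parts = [f" {k}: {details[k]}ms" for k in keys if k in details]
    return " (configured timeouts:" + ",".join(parts) + ")" if parts else ""

assert format_timeout_details({"connectTimeoutMS": 20000.0}) == " (configured timeouts: connectTimeoutMS: 20000.0ms)"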
 class AsyncConnection:
     """Store a connection with some metadata.
@@ -261,7 +134,11 @@ class AsyncConnection:
     """

     def __init__(
-        self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
+        self,
+        conn: AsyncNetworkingInterface,
+        pool: Pool,
+        address: tuple[str, int],
+        id: int,
     ):
         self.pool_ref = weakref.ref(pool)
         self.conn = conn
@@ -310,13 +187,15 @@ def __init__(
         self.connect_rtt = 0.0
         self._client_id = pool._client_id
         self.creation_time = time.monotonic()
+        # For gossiping $clusterTime from the connection handshake to the client.
+        self._cluster_time = None

     def set_conn_timeout(self, timeout: Optional[float]) -> None:
         """Cache last timeout to avoid duplicate calls to conn.settimeout."""
         if timeout == self.last_timeout:
             return
         self.last_timeout = timeout
-        self.conn.settimeout(timeout)
+        self.conn.get_conn.settimeout(timeout)

     def apply_timeout(
         self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
@@ -362,7 +241,7 @@ async def unpin(self) -> None:
         if pool:
             await pool.checkin(self)
         else:
-            self.close_conn(ConnectionClosedReason.STALE)
+            await self.close_conn(ConnectionClosedReason.STALE)

     def hello_cmd(self) -> dict[str, Any]:
         # Handshake spec requires us to use OP_MSG+hello command for the
@@ -374,11 +253,10 @@ def hello_cmd(self) -> dict[str, Any]:
             return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}

     async def hello(self) -> Hello:
-        return await self._hello(None, None, None)
+        return await self._hello(None, None)

     async def _hello(
         self,
-        cluster_time: Optional[ClusterTime],
         topology_version: Optional[Any],
         heartbeat_frequency: Optional[int],
     ) -> Hello[dict[str, Any]]:
@@ -401,9 +279,6 @@ async def _hello(
             if self.opts.connect_timeout:
                 self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency)

-        if not performing_handshake and cluster_time is not None:
-            cmd["$clusterTime"] = cluster_time
-
         creds = self.opts._credentials
         if creds:
             if creds.mechanism == "DEFAULT" and creds.username:
@@ -559,9 +434,9 @@ async def command(
             )
         except (OperationFailure, NotPrimaryError):
             raise
-        # Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
+        # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
         except BaseException as error:
-            self._raise_connection_failure(error)
+            await self._raise_connection_failure(error)

     async def send_message(self, message: bytes, max_doc_size: int) -> None:
         """Send a raw BSON message or raise ConnectionFailure.
@@ -575,9 +450,10 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
             )

         try:
-            await async_sendall(self.conn, message)
+            await async_sendall(self.conn.get_conn, message)
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException as error:
-            self._raise_connection_failure(error)
+            await self._raise_connection_failure(error)

     async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
         """Receive a raw BSON message or raise ConnectionFailure.
@@ -585,9 +461,10 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
         If any exception is raised, the socket is closed.
         """
         try:
-            return await receive_message(self, request_id, self.max_message_size)
+            return await async_receive_message(self, request_id, self.max_message_size)
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException as error:
-            self._raise_connection_failure(error)
+            await self._raise_connection_failure(error)

     def _raise_if_not_writable(self, unacknowledged: bool) -> None:
         """Raise NotPrimaryError on unacknowledged write if this socket is not
@@ -652,8 +529,8 @@ async def authenticate(self, reauthenticate: bool = False) -> None:
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CONN_READY,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     driverConnectionId=self.id,
@@ -673,11 +550,11 @@
                 "Can only use session with the AsyncMongoClient that started it"
             )

-    def close_conn(self, reason: Optional[str]) -> None:
+    async def close_conn(self, reason: Optional[str]) -> None:
         """Close this connection with a reason."""
         if self.closed:
             return
-        self._close_conn()
+        await self._close_conn()
         if reason:
             if self.enabled_for_cmap:
                 assert self.listeners is not None
@@ -685,8 +562,8 @@
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CONN_CLOSED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     driverConnectionId=self.id,
@@ -694,7 +571,7 @@
                     error=reason,
                 )

-    def _close_conn(self) -> None:
+    async def _close_conn(self) -> None:
         """Close this connection."""
         if self.closed:
             return
@@ -703,15 +580,16 @@
         # Note: We catch exceptions to avoid spurious errors on interpreter
         # shutdown.
         try:
-            self.conn.close()
-        except asyncio.CancelledError:
-            raise
+            await self.conn.close()
         except Exception:  # noqa: S110
             pass

     def conn_closed(self) -> bool:
         """Return True if we know socket has been closed, False otherwise."""
-        return self.socket_checker.socket_closed(self.conn)
+        if _IS_SYNC:
+            return self.socket_checker.socket_closed(self.conn.get_conn)
+        else:
+            return self.conn.is_closing()

     def send_cluster_time(
         self,
@@ -738,7 +616,7 @@ def idle_time_seconds(self) -> float:
         """Seconds since this socket was last checked into its pool."""
         return time.monotonic() - self.last_checkin_time

-    def _raise_connection_failure(self, error: BaseException) -> NoReturn:
+    async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
         # Catch *all* exceptions from socket methods and close the socket. In
         # regular Python, socket operations only raise socket.error, even if
         # the underlying cause was a Ctrl-C: a signal raised during socket.recv
@@ -758,9 +636,9 @@
             reason = None
         else:
             reason = ConnectionClosedReason.ERROR
-        self.close_conn(reason)
+        await self.close_conn(reason)
         # SSLError from PyOpenSSL inherits directly from Exception.
-        if isinstance(error, (IOError, OSError, SSLError)):
+        if isinstance(error, (IOError, OSError, *SSLErrors)):
             details = _get_timeout_details(self.opts)
             _raise_connection_failure(self.address, error, timeout_details=details)
         else:
@@ -783,145 +661,6 @@ def __repr__(self) -> str:
         )

-
-async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
-    """Given (host, port) and PoolOptions, connect and return a socket object.
-
-    Can raise socket.error.
-
-    This is a modified version of create_connection from CPython >= 2.7.
- """ - host, port = address - - # Check if dealing with a unix domain socket - if host.endswith(".sock"): - if not hasattr(socket, "AF_UNIX"): - raise ConnectionFailure("UNIX-sockets are not supported on this system") - sock = socket.socket(socket.AF_UNIX) - # SOCK_CLOEXEC not supported for Unix sockets. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.connect(host) - return sock - except OSError: - sock.close() - raise - - # Don't try IPv6 if we don't support it. Also skip it if host - # is 'localhost' (::1 is fine). Avoids slow connect issues - # like PYTHON-356. - family = socket.AF_INET - if socket.has_ipv6 and host != "localhost": - family = socket.AF_UNSPEC - - err = None - for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] - af, socktype, proto, dummy, sa = res - # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited - # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 - # all file descriptors are created non-inheritable. See PEP 446. - try: - sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) - except OSError: - # Can SOCK_CLOEXEC be defined even if the kernel doesn't support - # it? - sock = socket.socket(af, socktype, proto) - # Fallback when SOCK_CLOEXEC isn't available. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # CSOT: apply timeout to socket connect. - timeout = _csot.remaining() - if timeout is None: - timeout = options.connect_timeout - elif timeout <= 0: - raise socket.timeout("timed out") - sock.settimeout(timeout) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) - _set_keepalive_times(sock) - sock.connect(sa) - return sock - except OSError as e: - err = e - sock.close() - - if err is not None: - raise err - else: - # This likely means we tried to connect to an IPv6 only - # host with an OS/kernel or Python interpreter that doesn't - # support IPv6. The test case is Jython2.5.1 which doesn't - # support IPv6 at all. - raise OSError("getaddrinfo failed") - - -async def _configured_socket( - address: _Address, options: PoolOptions -) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = await _create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. 
-        if HAS_SNI:
-            if _IS_SYNC:
-                ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
-            else:
-                if hasattr(ssl_context, "a_wrap_socket"):
-                    ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host)  # type: ignore[assignment, misc]
-                else:
-                    loop = asyncio.get_running_loop()
-                    ssl_sock = await loop.run_in_executor(
-                        None,
-                        functools.partial(ssl_context.wrap_socket, sock, server_hostname=host),  # type: ignore[assignment, misc]
-                    )
-        else:
-            if _IS_SYNC:
-                ssl_sock = ssl_context.wrap_socket(sock)
-            else:
-                if hasattr(ssl_context, "a_wrap_socket"):
-                    ssl_sock = await ssl_context.a_wrap_socket(sock)  # type: ignore[assignment, misc]
-                else:
-                    loop = asyncio.get_running_loop()
-                    ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock)  # type: ignore[assignment, misc]
-    except _CertificateError:
-        sock.close()
-        # Raise _CertificateError directly like we do after match_hostname
-        # below.
-        raise
-    except (OSError, SSLError) as exc:
-        sock.close()
-        # We raise AutoReconnect for transient and permanent SSL handshake
-        # failures alike. Permanent handshake failures, like protocol
-        # mismatch, will be turned into ServerSelectionTimeoutErrors later.
-        details = _get_timeout_details(options)
-        _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
-    if (
-        ssl_context.verify_mode
-        and not ssl_context.check_hostname
-        and not options.tls_allow_invalid_hostnames
-    ):
-        try:
-            ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)  # type:ignore[attr-defined]
-        except _CertificateError:
-            ssl_sock.close()
-            raise
-
-    ssl_sock.settimeout(options.socket_timeout)
-    return ssl_sock
-
-
 class _PoolClosedError(PyMongoError):
     """Internal error raised when a thread tries to get a connection from a
     closed pool.
@@ -966,7 +705,7 @@ class PoolState:

 # Do *not* explicitly inherit from object or Jython won't call __del__
-# http://bugs.jython.org/issue1057
+# https://bugs.jython.org/issue1057
 class Pool:
     def __init__(
@@ -1039,8 +778,8 @@ def __init__(
         if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _CONNECTION_LOGGER,
-                clientId=self._client_id,
                 message=_ConnectionStatusMessage.POOL_CREATED,
+                clientId=self._client_id,
                 serverHost=self.address[0],
                 serverPort=self.address[1],
                 **self.opts.non_default_options,
@@ -1065,8 +804,8 @@ async def ready(self) -> None:
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.POOL_READY,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                 )
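The _reset changes in the next hunk close sockets concurrently in async mode. The pattern, isolated: gather the close coroutines and swallow per-connection failures so one bad socket cannot stall a pool reset (a sketch; close_conn stands in for AsyncConnection.close_conn):

import asyncio

async def close_all(conns: list, reason: str) -> None:
    await asyncio.gather(
        *(conn.close_conn(reason) for conn in conns),
        return_exceptions=True,  # collect exceptions instead of raising
    )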
@@ -1122,16 +861,22 @@ async def _reset(
         # PoolClosedEvent but that reset() SHOULD close sockets *after*
         # publishing the PoolClearedEvent.
         if close:
-            for conn in sockets:
-                conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
+            if not _IS_SYNC:
+                await asyncio.gather(
+                    *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
+                    return_exceptions=True,
+                )
+            else:
+                for conn in sockets:
+                    await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
             if self.enabled_for_cmap:
                 assert listeners is not None
                 listeners.publish_pool_closed(self.address)
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.POOL_CLOSED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                 )
@@ -1147,14 +892,20 @@ async def _reset(
                 if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                     _debug_log(
                         _CONNECTION_LOGGER,
-                        clientId=self._client_id,
                         message=_ConnectionStatusMessage.POOL_CLEARED,
+                        clientId=self._client_id,
                         serverHost=self.address[0],
                         serverPort=self.address[1],
                         serviceId=service_id,
                     )
-        for conn in sockets:
-            conn.close_conn(ConnectionClosedReason.STALE)
+        if not _IS_SYNC:
+            await asyncio.gather(
+                *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
+                return_exceptions=True,
+            )
+        else:
+            for conn in sockets:
+                await conn.close_conn(ConnectionClosedReason.STALE)

     async def update_is_writable(self, is_writable: Optional[bool]) -> None:
         """Updates the is_writable attribute on all sockets currently in the
@@ -1193,13 +944,21 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
             return

         if self.opts.max_idle_time_seconds is not None:
+            close_conns = []
             async with self.lock:
                 while (
                     self.conns
                     and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
                 ):
-                    conn = self.conns.pop()
-                    conn.close_conn(ConnectionClosedReason.IDLE)
+                    close_conns.append(self.conns.pop())
+            if not _IS_SYNC:
+                await asyncio.gather(
+                    *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
+                    return_exceptions=True,
+                )
+            else:
+                for conn in close_conns:
+                    await conn.close_conn(ConnectionClosedReason.IDLE)

         while True:
             async with self.size_cond:
@@ -1219,14 +978,18 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
                     self._pending += 1
                     incremented = True
                 conn = await self.connect()
+                close_conn = False
                 async with self.lock:
                     # Close connection and return if the pool was reset during
                    # socket creation or while acquiring the pool lock.
                     if self.gen.get_overall() != reference_generation:
-                        conn.close_conn(ConnectionClosedReason.STALE)
-                        return
-                    self.conns.appendleft(conn)
-                    self.active_contexts.discard(conn.cancel_context)
+                        close_conn = True
+                    if not close_conn:
+                        self.conns.appendleft(conn)
+                        self.active_contexts.discard(conn.cancel_context)
+                if close_conn:
+                    await conn.close_conn(ConnectionClosedReason.STALE)
+                    return
             finally:
                 if incremented:
                     # Notify after adding the socket to the pool.
@@ -1260,15 +1023,16 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
         if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _CONNECTION_LOGGER,
-                clientId=self._client_id,
                 message=_ConnectionStatusMessage.CONN_CREATED,
+                clientId=self._client_id,
                 serverHost=self.address[0],
                 serverPort=self.address[1],
                 driverConnectionId=conn_id,
             )

         try:
-            sock = await _configured_socket(self.address, self.opts)
+            networking_interface = await _configured_protocol_interface(self.address, self.opts)
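remove_stale_sockets above illustrates the locking discipline used throughout this file now that close_conn awaits: collect doomed connections while holding the pool lock, then close them after releasing it (a sketch; is_idle is a hypothetical predicate):

import asyncio

async def close_idle(lock: asyncio.Lock, conns: list, is_idle) -> None:
    doomed = []
    async with lock:
        while conns and is_idle(conns[-1]):
            doomed.append(conns.pop())
    # I/O happens outside the lock so other tasks are never blocked on close.
    for conn in doomed:
        await conn.close_conn("idle")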
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException as error:
             async with self.lock:
                 self.active_contexts.discard(tmp_context)
@@ -1280,21 +1044,21 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CONN_CLOSED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     driverConnectionId=conn_id,
                     reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
                     error=ConnectionClosedReason.ERROR,
                 )
-            if isinstance(error, (IOError, OSError, SSLError)):
+            if isinstance(error, (IOError, OSError, *SSLErrors)):
                 details = _get_timeout_details(self.opts)
                 _raise_connection_failure(self.address, error, timeout_details=details)

             raise

-        conn = AsyncConnection(sock, self, self.address, conn_id)  # type: ignore[arg-type]
+        conn = AsyncConnection(networking_interface, self, self.address, conn_id)  # type: ignore[arg-type]
         async with self.lock:
             self.active_contexts.add(conn.cancel_context)
             self.active_contexts.discard(tmp_context)
@@ -1308,12 +1072,16 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
                 handler.contribute_socket(conn, completed_handshake=False)

             await conn.authenticate()
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException:
             async with self.lock:
                 self.active_contexts.discard(conn.cancel_context)
-            conn.close_conn(ConnectionClosedReason.ERROR)
+            await conn.close_conn(ConnectionClosedReason.ERROR)
             raise

+        if handler:
+            await handler.client._topology.receive_cluster_time(conn._cluster_time)
+
         return conn

     @contextlib.asynccontextmanager
@@ -1343,8 +1111,8 @@ async def checkout(
         if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _CONNECTION_LOGGER,
-                clientId=self._client_id,
                 message=_ConnectionStatusMessage.CHECKOUT_STARTED,
+                clientId=self._client_id,
                 serverHost=self.address[0],
                 serverPort=self.address[1],
             )
@@ -1358,8 +1126,8 @@ async def checkout(
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     driverConnectionId=conn.id,
@@ -1369,6 +1137,7 @@
             async with self.lock:
                 self.active_contexts.add(conn.cancel_context)
             yield conn
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException:
             # Exception in caller. Ensure the connection gets returned.
             # Note that when pinned is True, the session owns the
@@ -1406,8 +1175,8 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) ->
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CHECKOUT_FAILED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     reason="An error occurred while trying to establish a new connection",
@@ -1440,8 +1209,8 @@ async def _get_conn(
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CHECKOUT_FAILED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     reason="Connection pool was closed",
@@ -1505,7 +1274,7 @@ async def _get_conn(
                     except IndexError:
                         self._pending += 1
                 if conn:  # We got a socket from the pool
-                    if self._perished(conn):
+                    if await self._perished(conn):
                         conn = None
                         continue
                 else:  # We need to create a new connection
@@ -1515,10 +1284,11 @@ async def _get_conn(
                         async with self._max_connecting_cond:
                             self._pending -= 1
                             self._max_connecting_cond.notify()
+        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
         except BaseException:
             if conn:
                 # We checked out a socket but authentication failed.
-                conn.close_conn(ConnectionClosedReason.ERROR)
+                await conn.close_conn(ConnectionClosedReason.ERROR)
             async with self.size_cond:
                 self.requests -= 1
                 if incremented:
@@ -1535,8 +1305,8 @@ async def _get_conn(
             if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _CONNECTION_LOGGER,
-                    clientId=self._client_id,
                     message=_ConnectionStatusMessage.CHECKOUT_FAILED,
+                    clientId=self._client_id,
                     serverHost=self.address[0],
                     serverPort=self.address[1],
                     reason="An error occurred while trying to establish a new connection",
@@ -1568,8 +1338,8 @@ async def checkin(self, conn: AsyncConnection) -> None:
         if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _CONNECTION_LOGGER,
-                clientId=self._client_id,
                 message=_ConnectionStatusMessage.CHECKEDIN,
+                clientId=self._client_id,
                 serverHost=self.address[0],
                 serverPort=self.address[1],
                 driverConnectionId=conn.id,
@@ -1578,7 +1348,7 @@ async def checkin(self, conn: AsyncConnection) -> None:
             await self.reset_without_pause()
         else:
             if self.closed:
-                conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
+                await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
             elif conn.closed:
                 # CMAP requires the closed event be emitted after the check in.
                 if self.enabled_for_cmap:
@@ -1589,8 +1359,8 @@ async def checkin(self, conn: AsyncConnection) -> None:
                 if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
                     _debug_log(
                         _CONNECTION_LOGGER,
-                        clientId=self._client_id,
                         message=_ConnectionStatusMessage.CONN_CLOSED,
+                        clientId=self._client_id,
                         serverHost=self.address[0],
                         serverPort=self.address[1],
                         driverConnectionId=conn.id,
@@ -1598,17 +1368,20 @@
                         error=ConnectionClosedReason.ERROR,
                     )
             else:
+                close_conn = False
                 async with self.lock:
                     # Hold the lock to ensure this section does not race with
                     # Pool.reset().
                     if self.stale_generation(conn.generation, conn.service_id):
-                        conn.close_conn(ConnectionClosedReason.STALE)
+                        close_conn = True
                     else:
                         conn.update_last_checkin_time()
                         conn.update_is_writable(bool(self.is_writable))
                         self.conns.appendleft(conn)
                         # Notify any threads waiting to create a connection.
                         self._max_connecting_cond.notify()
+                if close_conn:
+                    await conn.close_conn(ConnectionClosedReason.STALE)

         async with self.size_cond:
             if txn:
@@ -1620,7 +1393,7 @@ async def checkin(self, conn: AsyncConnection) -> None:
             self.operation_count -= 1
             self.size_cond.notify()

-    def _perished(self, conn: AsyncConnection) -> bool:
+    async def _perished(self, conn: AsyncConnection) -> bool:
         """Return True and close the connection if it is "perished".

         This side-effecty function checks if this socket has been idle for
@@ -1640,18 +1413,18 @@ def _perished(self, conn: AsyncConnection) -> bool:
             self.opts.max_idle_time_seconds is not None
             and idle_time_seconds > self.opts.max_idle_time_seconds
         ):
-            conn.close_conn(ConnectionClosedReason.IDLE)
+            await conn.close_conn(ConnectionClosedReason.IDLE)
             return True

         if self._check_interval_seconds is not None and (
             self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
         ):
             if conn.conn_closed():
-                conn.close_conn(ConnectionClosedReason.ERROR)
+                await conn.close_conn(ConnectionClosedReason.ERROR)
                 return True

         if self.stale_generation(conn.generation, conn.service_id):
-            conn.close_conn(ConnectionClosedReason.STALE)
+            await conn.close_conn(ConnectionClosedReason.STALE)
             return True

         return False
@@ -1667,8 +1440,8 @@ def _raise_wait_queue_timeout(self, checkout_started_time: float) -> NoReturn:
         if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _CONNECTION_LOGGER,
-                clientId=self._client_id,
                 message=_ConnectionStatusMessage.CHECKOUT_FAILED,
+                clientId=self._client_id,
                 serverHost=self.address[0],
                 serverPort=self.address[1],
                 reason="Wait queue timeout elapsed without a connection becoming available",
@@ -1699,5 +1472,6 @@ def __del__(self) -> None:
         # Avoid ResourceWarnings in Python 3
         # Close all sockets without calling reset() or close() because it is
         # not safe to acquire a lock in __del__.
-        for conn in self.conns:
-            conn.close_conn(None)
+        if _IS_SYNC:
+            for conn in self.conns:
+                conn.close_conn(None)
diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py
index 72f22584e2..0e0d53b96f 100644
--- a/pymongo/asynchronous/server.py
+++ b/pymongo/asynchronous/server.py
@@ -4,7 +4,7 @@
 # may not use this file except in compliance with the License. You
 # may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -108,10 +108,10 @@ async def close(self) -> None:
         if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _SDAM_LOGGER,
+                message=_SDAMStatusMessage.STOP_SERVER,
                 topologyId=self._topology_id,
                 serverHost=self._description.address[0],
                 serverPort=self._description.address[1],
-                message=_SDAMStatusMessage.STOP_SERVER,
             )

         await self._monitor.close()
@@ -173,8 +173,8 @@ async def run_operation(
             if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _COMMAND_LOGGER,
-                    clientId=client._topology_settings._topology_id,
                     message=_CommandStatusMessage.STARTED,
+                    clientId=client._topology_settings._topology_id,
                     command=cmd,
                     commandName=next(iter(cmd)),
                     databaseName=dbn,
@@ -234,8 +234,8 @@ async def run_operation(
                 if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                     _debug_log(
                         _COMMAND_LOGGER,
-                        clientId=client._topology_settings._topology_id,
                         message=_CommandStatusMessage.FAILED,
+                        clientId=client._topology_settings._topology_id,
                         durationMS=duration,
                         failure=failure,
                         commandName=next(iter(cmd)),
@@ -278,8 +278,8 @@ async def run_operation(
             if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _COMMAND_LOGGER,
-                    clientId=client._topology_settings._topology_id,
                     message=_CommandStatusMessage.SUCCEEDED,
+                    clientId=client._topology_settings._topology_id,
                     durationMS=duration,
                     reply=res,
                     commandName=next(iter(cmd)),
diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py
index 1103e1bd18..9c2331971a 100644
--- a/pymongo/asynchronous/settings.py
+++ b/pymongo/asynchronous/settings.py
@@ -4,7 +4,7 @@
 # may not use this file except in compliance with the License. You
 # may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -51,6 +51,7 @@ def __init__(
         srv_service_name: str = common.SRV_SERVICE_NAME,
         srv_max_hosts: int = 0,
         server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
+        topology_id: Optional[ObjectId] = None,
     ):
         """Represent MongoClient's configuration.
@@ -78,8 +79,10 @@ def __init__(
         self._srv_service_name = srv_service_name
         self._srv_max_hosts = srv_max_hosts or 0
         self._server_monitoring_mode = server_monitoring_mode
-
-        self._topology_id = ObjectId()
+        if topology_id is not None:
+            self._topology_id = topology_id
+        else:
+            self._topology_id = ObjectId()
         # Store the allocation traceback to catch unclosed clients in the
         # test suite.
         self._stack = "".join(traceback.format_stack()[:-2])
diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py
new file mode 100644
index 0000000000..9d1b8fe141
--- /dev/null
+++ b/pymongo/asynchronous/srv_resolver.py
@@ -0,0 +1,164 @@
+# Copyright 2019-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+"""Support for resolving hosts and options from mongodb+srv:// URIs."""
+from __future__ import annotations
+
+import ipaddress
+import random
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from pymongo.common import CONNECT_TIMEOUT
+from pymongo.errors import ConfigurationError
+
+if TYPE_CHECKING:
+    from dns import resolver
+
+_IS_SYNC = False
+
+
+def _have_dnspython() -> bool:
+    try:
+        import dns  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
+# dnspython can return bytes or str from various parts
+# of its API depending on version. We always want str.
+def maybe_decode(text: Union[str, bytes]) -> str:
+    if isinstance(text, bytes):
+        return text.decode()
+    return text
+
+
+# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
+async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
+    if _IS_SYNC:
+        from dns import resolver
+
+        if hasattr(resolver, "resolve"):
+            # dnspython >= 2
+            return resolver.resolve(*args, **kwargs)
+        # dnspython 1.X
+        return resolver.query(*args, **kwargs)
+    else:
+        from dns import asyncresolver
+
+        if hasattr(asyncresolver, "resolve"):
+            # dnspython >= 2
+            return await asyncresolver.resolve(*args, **kwargs)  # type:ignore[return-value]
+        raise ConfigurationError(
+            "Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
+        )
+
+
+_INVALID_HOST_MSG = (
+    "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
+    "Did you mean to use 'mongodb://'?"
+)
+
+
+class _SrvResolver:
+    def __init__(
+        self,
+        fqdn: str,
+        connect_timeout: Optional[float],
+        srv_service_name: str,
+        srv_max_hosts: int = 0,
+    ):
+        self.__fqdn = fqdn
+        self.__srv = srv_service_name
+        self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
+        self.__srv_max_hosts = srv_max_hosts or 0
+        # Validate the fully qualified domain name.
+        try:
+            ipaddress.ip_address(fqdn)
+            raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
+        except ValueError:
+            pass
+        try:
+            split_fqdn = self.__fqdn.split(".")
+            self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn
+        except Exception:
+            raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
+        self.__slen = len(self.__plist)
+        self.nparts = len(split_fqdn)
+
+    async def get_options(self) -> Optional[str]:
+        from dns import resolver
+
+        try:
+            results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
+        except (resolver.NoAnswer, resolver.NXDOMAIN):
+            # No TXT records
+            return None
+        except Exception as exc:
+            raise ConfigurationError(str(exc)) from None
+        if len(results) > 1:
+            raise ConfigurationError("Only one TXT record is supported")
+        return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8")  # type: ignore[attr-defined]
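get_options above pulls default URI options from the host's TXT record. Done directly against dnspython's async API (assuming dnspython >= 2.0), the query amounts to:

import dns.asyncresolver
from dns import resolver
from typing import Optional

async def srv_txt_options(fqdn: str, lifetime: float = 20.0) -> Optional[str]:
    try:
        answer = await dns.asyncresolver.resolve(fqdn, "TXT", lifetime=lifetime)
    except (resolver.NoAnswer, resolver.NXDOMAIN):
        return None  # a missing TXT record is not an error
    return b"&".join(b"".join(rr.strings) for rr in answer).decode("utf-8")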
+    async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
+        try:
+            results = await _resolve(
+                "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
+            )
+        except Exception as exc:
+            if not encapsulate_errors:
+                # Raise the original error.
+                raise
+            # Else, raise all errors as ConfigurationError.
+            raise ConfigurationError(str(exc)) from None
+        return results
+
+    async def _get_srv_response_and_hosts(
+        self, encapsulate_errors: bool
+    ) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
+        results = await self._resolve_uri(encapsulate_errors)
+
+        # Construct address tuples
+        nodes = [
+            (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)  # type: ignore[attr-defined]
+            for res in results
+        ]
+
+        # Validate hosts
+        for node in nodes:
+            srv_host = node[0].lower()
+            if self.__fqdn == srv_host and self.nparts < 3:
+                raise ConfigurationError(
+                    "Invalid SRV host: return address is identical to SRV hostname"
+                )
+            try:
+                nlist = srv_host.split(".")[1:][-self.__slen :]
+            except Exception:
+                raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
+            if self.__plist != nlist:
+                raise ConfigurationError(f"Invalid SRV host: {node[0]}")
+        if self.__srv_max_hosts:
+            nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
+        return results, nodes
+
+    async def get_hosts(self) -> list[tuple[str, Any]]:
+        _, nodes = await self._get_srv_response_and_hosts(True)
+        return nodes
+
+    async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
+        results, nodes = await self._get_srv_response_and_hosts(False)
+        rrset = results.rrset
+        ttl = rrset.ttl if rrset else 0
+        return nodes, ttl
diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py
index 6d67710a7e..438dd1e352 100644
--- a/pymongo/asynchronous/topology.py
+++ b/pymongo/asynchronous/topology.py
@@ -4,7 +4,7 @@
 # may not use this file except in compliance with the License. You
 # may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import asyncio
 import logging
 import os
 import queue
@@ -29,7 +30,7 @@

 from pymongo import _csot, common, helpers_shared, periodic_executor
 from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
-from pymongo.asynchronous.monitor import SrvMonitor
+from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
 from pymongo.asynchronous.pool import Pool
 from pymongo.asynchronous.server import Server
 from pymongo.errors import (
@@ -40,6 +41,7 @@
     OperationFailure,
     PyMongoError,
     ServerSelectionTimeoutError,
+    WaitQueueTimeoutError,
     WriteError,
 )
 from pymongo.hello import Hello
@@ -118,8 +120,8 @@ def __init__(self, topology_settings: TopologySettings):
         if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _SDAM_LOGGER,
-                topologyId=self._topology_id,
                 message=_SDAMStatusMessage.START_TOPOLOGY,
+                topologyId=self._topology_id,
             )

         if self._publish_tp:
@@ -150,10 +152,10 @@ def __init__(self, topology_settings: TopologySettings):
         if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
             _debug_log(
                 _SDAM_LOGGER,
+                message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
                 topologyId=self._topology_id,
                 previousDescription=initial_td,
                 newDescription=self._description,
-                message=_SDAMStatusMessage.TOPOLOGY_CHANGE,
             )

         for seed in topology_settings.seeds:
@@ -163,10 +165,10 @@ def __init__(self, topology_settings: TopologySettings):
             if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
                 _debug_log(
                     _SDAM_LOGGER,
+                    message=_SDAMStatusMessage.START_SERVER,
                     topologyId=self._topology_id,
serverHost=seed[0], serverPort=seed[1], - message=_SDAMStatusMessage.START_SERVER, ) # Store the seed list to help diagnose errors in _error_message(). @@ -207,6 +209,9 @@ async def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + async def open(self) -> None: """Start monitoring, or restart after a fork. @@ -232,9 +237,7 @@ async def open(self) -> None: warnings.warn( # type: ignore[call-overload] # noqa: B028 "AsyncMongoClient opened before fork. May not be entirely fork-safe, " "proceed with caution. See PyMongo's documentation for details: " - "https://www.mongodb.com/docs/languages/" - "python/pymongo-driver/current/faq/" - "#is-pymongo-fork-safe-", + "https://dochub.mongodb.org/core/pymongo-fork-deadlock", **kwargs, ) async with self._lock: @@ -283,6 +286,10 @@ async def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + await self.cleanup_monitors() + async with self._lock: server_descriptions = await self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -347,7 +354,7 @@ async def _select_servers_loop( operationId=operation_id, topologyDescription=self.description, clientId=self.description._topology_settings._topology_id, - remainingTimeMS=int(end_time - time.monotonic()), + remainingTimeMS=int(1000 * (end_time - time.monotonic())), ) logged_waiting = True @@ -493,7 +500,6 @@ async def _process_change( self._description = new_td await self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None @@ -506,10 +512,10 @@ async def _process_change( if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event: _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=td_old, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) # Shutdown SRV polling for unsupported cluster types. @@ -520,12 +526,8 @@ async def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() - - # Clear the pool from a failed heartbeat. - if reset_pool: - server = self._servers.get(server_description.address) - if server: - await server.pool.reset(interrupt_connections=interrupt_connections) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Wake anything waiting in select_servers(). self._condition.notify_all() @@ -549,6 +551,11 @@ async def on_change( # that didn't include this server. if self._opened and self._description.has_server(server_description.address): await self._process_change(server_description, reset_pool, interrupt_connections) + # Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close. + if reset_pool: + server = self._servers.get(server_description.address) + if server: + await server.pool.reset(interrupt_connections=interrupt_connections) async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new seedlist on an opened topology. 
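The `_monitor_tasks` list above exists because a closed monitor still owns a background task that must eventually be awaited; the topology collects those monitors and reaps them on the next server selection or on close. A minimal sketch of that drain-then-gather pattern (`_StubMonitor` is a hypothetical stand-in for the real monitor classes):

import asyncio


class _StubMonitor:
    """Hypothetical stand-in for a monitor whose join() awaits its task."""

    def __init__(self, name: str) -> None:
        self._task = asyncio.create_task(asyncio.sleep(0.01), name=name)

    async def join(self) -> None:
        await self._task


async def cleanup_monitors(monitor_tasks: list[_StubMonitor]) -> None:
    # Drain the shared list first so entries appended concurrently are not
    # lost mid-iteration, then await every join, tolerating failures.
    tasks = []
    while monitor_tasks:
        tasks.append(monitor_tasks.pop())
    await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True)


async def main() -> None:
    pending = [_StubMonitor(f"monitor-{i}") for i in range(3)]
    await cleanup_monitors(pending)


asyncio.run(main())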
@@ -572,10 +579,10 @@ async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=td_old, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) async def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: @@ -695,6 +702,8 @@ async def close(self) -> None: old_td = self._description for server in self._servers.values(): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -705,6 +714,8 @@ async def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -734,13 +745,13 @@ async def close(self) -> None: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=old_td, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) _debug_log( - _SDAM_LOGGER, topologyId=self._topology_id, message=_SDAMStatusMessage.STOP_TOPOLOGY + _SDAM_LOGGER, message=_SDAMStatusMessage.STOP_TOPOLOGY, topologyId=self._topology_id ) if self._publish_server or self._publish_tp: @@ -879,6 +890,8 @@ async def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None # Clear the pool. await server.reset(service_id) elif isinstance(error, ConnectionFailure): + if isinstance(error, WaitQueueTimeoutError): + return # "Client MUST replace the server's description with type Unknown # ... MUST NOT request an immediate check of the server." if not self._settings.load_balanced: @@ -944,6 +957,8 @@ async def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1031,6 +1046,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + async def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py new file mode 100644 index 0000000000..47c6d72031 --- /dev/null +++ b/pymongo/asynchronous/uri_parser.py @@ -0,0 +1,188 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. 
See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+"""Tools to parse and validate a MongoDB URI."""
+from __future__ import annotations
+
+from typing import Any, Optional
+from urllib.parse import unquote_plus
+
+from pymongo.asynchronous.srv_resolver import _SrvResolver
+from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
+from pymongo.errors import ConfigurationError, InvalidURI
+from pymongo.uri_parser_shared import (
+    _ALLOWED_TXT_OPTS,
+    DEFAULT_PORT,
+    SCHEME,
+    SCHEME_LEN,
+    SRV_SCHEME_LEN,
+    _check_options,
+    _validate_uri,
+    split_hosts,
+    split_options,
+)
+
+_IS_SYNC = False
+
+
+async def parse_uri(
+    uri: str,
+    default_port: Optional[int] = DEFAULT_PORT,
+    validate: bool = True,
+    warn: bool = False,
+    normalize: bool = True,
+    connect_timeout: Optional[float] = None,
+    srv_service_name: Optional[str] = None,
+    srv_max_hosts: Optional[int] = None,
+) -> dict[str, Any]:
+    """Parse and validate a MongoDB URI.
+
+    Returns a dict of the form::
+
+        {
+            'nodelist': <list of (host, port) tuples>,
+            'username': <username> or None,
+            'password': <password> or None,
+            'database': <database name> or None,
+            'collection': <collection name> or None,
+            'options': <dict of MongoDB URI options>,
+            'fqdn': <fqdn of the MongoDB+SRV URI> or None
+        }
+
+    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
+    to build nodelist and options.
+
+    :param uri: The MongoDB URI to parse.
+    :param default_port: The port number to use when one wasn't specified
+        for a host in the URI.
+    :param validate: If ``True`` (the default), validate and
+        normalize all options. Default: ``True``.
+    :param warn: When validating, if ``True`` then will warn the user and
+        ignore any invalid options or values. If ``False``, validation will
+        error when options are unsupported or values are invalid.
+        Default: ``False``.
+    :param normalize: If ``True``, convert names of URI options
+        to their internally-used names. Default: ``True``.
+    :param connect_timeout: The maximum time in milliseconds to
+        wait for a response from the DNS server.
+    :param srv_service_name: A custom SRV service name.
+
+    .. versionchanged:: 4.6
+        The delimiting slash (``/``) between hosts and connection options is now optional.
+        For example, "mongodb://example.com?tls=true" is now a valid URI.
+
+    .. versionchanged:: 4.0
+        To better follow RFC 3986, unquoted percent signs ("%") are no longer
+        supported.
+
+    .. versionchanged:: 3.9
+        Added the ``normalize`` parameter.
+
+    .. versionchanged:: 3.6
+        Added support for mongodb+srv:// URIs.
+
+    .. versionchanged:: 3.5
+        Return the original value of the ``readPreference`` MongoDB URI option
+        instead of the validated read preference mode.
+
+    .. versionchanged:: 3.1
+        ``warn`` added so invalid options can be ignored.
+ """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + await _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +async def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = await dns_resolver.get_hosts() + dns_options = await dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/auth.py b/pymongo/auth.py index a65113841d..a36f3f4233 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 4ac266de5f..61764b8111 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/auth_oidc_shared.py b/pymongo/auth_oidc_shared.py index 9e0acaf6c8..d33397f52d 100644 --- a/pymongo/auth_oidc_shared.py +++ b/pymongo/auth_oidc_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/auth_shared.py b/pymongo/auth_shared.py index 9534bd74ad..5a9a2b6732 100644 --- a/pymongo/auth_shared.py +++ b/pymongo/auth_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -107,7 +107,7 @@ def _build_credentials_tuple( ) -> MongoCredential: """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError(f"{mech} requires a username.") + raise ConfigurationError(f"{mech} requires a username") if mech == "GSSAPI": if source is not None and source != "$external": raise ValueError("authentication source must be $external or None for GSSAPI") @@ -219,7 +219,7 @@ def _build_credentials_tuple( else: source_database = source or database or "admin" if passwd is None: - raise ConfigurationError("A password is required.") + raise ConfigurationError("A password is required") return MongoCredential(mech, source_database, user, passwd, None, _Cache()) diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py index 7aa6340d55..9276419d8a 100644 --- a/pymongo/bulk_shared.py +++ b/pymongo/bulk_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index b96a1750cf..f9abddec44 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 9b9b88a736..bd27dd4eb0 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -84,7 +84,9 @@ def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: return ReadConcern(concern) -def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: +def _parse_ssl_options( + options: Mapping[str, Any], is_sync: bool +) -> tuple[Optional[SSLContext], bool]: """Parse ssl options.""" use_tls = options.get("tls") if use_tls is not None: @@ -138,6 +140,7 @@ def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext] allow_invalid_certificates, allow_invalid_hostnames, disable_ocsp_endpoint_check, + is_sync, ) return ctx, allow_invalid_hostnames return None, allow_invalid_hostnames @@ -167,7 +170,7 @@ def _parse_pool_options( compression_settings = CompressionSettings( options.get("compressors", []), options.get("zlibcompressionlevel", -1) ) - ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) + ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options, is_sync) load_balanced = options.get("loadbalanced") max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) return PoolOptions( diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 1a3af44e12..db72b0b2e1 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/collation.py b/pymongo/collation.py index 9adcb2e408..8a1eca7aff 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -223,4 +223,4 @@ def validate_collation_or_none( return value.document if isinstance(value, dict): return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") + raise TypeError("collation must be a dict, an instance of collation.Collation, or None") diff --git a/pymongo/collection.py b/pymongo/collection.py index f726ed0376..16063425a7 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/common.py b/pymongo/common.py index b442da6a3e..3d8095eedf 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -160,7 +160,7 @@ def clean_node(node: str) -> tuple[str, int]: host, port = partition_node(node) # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 + # https://tools.ietf.org/html/rfc4343 # This prevents useless rediscovery if "foo.com" is in the seed list but # "FOO.com" is in the hello response. 
return host.lower(), port @@ -202,7 +202,7 @@ def validate_integer(option: str, value: Any) -> int: return int(value) except ValueError: raise ValueError(f"The value of {option} must be an integer") from None - raise TypeError(f"Wrong type for {option}, value must be an integer") + raise TypeError(f"Wrong type for {option}, value must be an integer, not {type(value)}") def validate_positive_integer(option: str, value: Any) -> int: @@ -250,7 +250,7 @@ def validate_string(option: str, value: Any) -> str: """Validates that 'value' is an instance of `str`.""" if isinstance(value, str): return value - raise TypeError(f"Wrong type for {option}, value must be an instance of str") + raise TypeError(f"Wrong type for {option}, value must be an instance of str, not {type(value)}") def validate_string_or_none(option: str, value: Any) -> Optional[str]: @@ -269,7 +269,9 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: return int(value) except ValueError: return value - raise TypeError(f"Wrong type for {option}, value must be an integer or a string") + raise TypeError( + f"Wrong type for {option}, value must be an integer or a string, not {type(value)}" + ) def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: @@ -282,7 +284,9 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in except ValueError: return value return validate_non_negative_integer(option, val) - raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") + raise TypeError( + f"Wrong type for {option}, value must be an non negative integer or a string, not {type(value)}" + ) def validate_positive_float(option: str, value: Any) -> float: @@ -365,7 +369,7 @@ def validate_max_staleness(option: str, value: Any) -> int: def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: """Validate a read preference.""" if not isinstance(value, _ServerMode): - raise TypeError(f"{value!r} is not a read preference.") + raise TypeError(f"{value!r} is not a read preference") return value @@ -441,7 +445,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni props: dict[str, Any] = {} if not isinstance(value, str): if not isinstance(value, dict): - raise ValueError("Auth mechanism properties must be given as a string or a dictionary") + raise ValueError( + f"Auth mechanism properties must be given as a string or a dictionary, not {type(value)}" + ) for key, value in value.items(): # noqa: B020 if isinstance(value, str): props[key] = value @@ -453,7 +459,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni from pymongo.auth_oidc_shared import OIDCCallback if not isinstance(value, OIDCCallback): - raise ValueError("callback must be an OIDCCallback object") + raise ValueError(f"callback must be an OIDCCallback object, not {type(value)}") props[key] = value else: raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") @@ -476,7 +482,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni raise ValueError( f"{key} is not a supported auth " "mechanism property. Must be one of " - f"{tuple(_MECHANISM_PROPS)}." 
+ f"{tuple(_MECHANISM_PROPS)}" ) if key == "CANONICALIZE_HOST_NAME": @@ -520,7 +526,7 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: def validate_list(option: str, value: Any) -> list: """Validates that 'value' is a list.""" if not isinstance(value, list): - raise TypeError(f"{option} must be a list") + raise TypeError(f"{option} must be a list, not {type(value)}") return value @@ -587,7 +593,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: if value is None: return value if not isinstance(value, ServerApi): - raise TypeError(f"{option} must be an instance of ServerApi") + raise TypeError(f"{option} must be an instance of ServerApi, not {type(value)}") return value @@ -596,7 +602,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: if value is None: return value if not callable(value): - raise ValueError(f"{option} must be a callable") + raise ValueError(f"{option} must be a callable, not {type(value)}") return value @@ -651,7 +657,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A from pymongo.encryption_options import AutoEncryptionOpts if not isinstance(value, AutoEncryptionOpts): - raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") + raise TypeError(f"{option} must be an instance of AutoEncryptionOpts, not {type(value)}") return value @@ -668,7 +674,9 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo elif isinstance(value, int): return DatetimeConversion(value) - raise TypeError(f"{option} must be a str or int representing DatetimeConversion") + raise TypeError( + f"{option} must be a str or int representing DatetimeConversion, not {type(value)}" + ) def validate_server_monitoring_mode(option: str, value: str) -> str: @@ -928,12 +936,14 @@ def __init__( if not isinstance(write_concern, WriteConcern): raise TypeError( - "write_concern must be an instance of pymongo.write_concern.WriteConcern" + f"write_concern must be an instance of pymongo.write_concern.WriteConcern, not {type(write_concern)}" ) self._write_concern = write_concern if not isinstance(read_concern, ReadConcern): - raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern") + raise TypeError( + f"read_concern must be an instance of pymongo.read_concern.ReadConcern, not {type(read_concern)}" + ) self._read_concern = read_concern @property diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index f49b56cc96..db14b8d83f 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -91,7 +91,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int: try: level = int(value) except Exception: - raise TypeError(f"{option} must be an integer, not {value!r}.") from None + raise TypeError(f"{option} must be an integer, not {value!r}") from None if level < -1 or level > 9: raise ValueError("%s must be between -1 and 9, not %d." 
% (option, level)) return level diff --git a/pymongo/daemon.py b/pymongo/daemon.py index b40384df13..be976decd9 100644 --- a/pymongo/daemon.py +++ b/pymongo/daemon.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/database.py b/pymongo/database.py index bbd05702dc..f85b312f91 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/database_shared.py b/pymongo/database_shared.py index 2d4e37feef..d6563a4b3d 100644 --- a/pymongo/database_shared.py +++ b/pymongo/database_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 5ca3f952cd..f24321d973 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -39,7 +39,7 @@ def __new__( for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): raise TypeError( - f"Wrong type for DriverInfo {key} option, value must be an instance of str" + f"Wrong type for DriverInfo {key} option, value must be an instance of str, not {type(value)}" ) return self diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 5bc2a75909..71c1d4b723 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index ee749e7ac1..e9ad1c1e01 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
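The validator changes in `pymongo/common.py` above all apply one pattern: include the offending `type(value)` in the `TypeError` message. For example, the reworked string validator behaves like this:

from typing import Any


def validate_string(option: str, value: Any) -> str:
    if isinstance(value, str):
        return value
    raise TypeError(
        f"Wrong type for {option}, value must be an instance of str, not {type(value)}"
    )


try:
    validate_string("appname", 42)
except TypeError as exc:
    print(exc)  # Wrong type for appname, ... not <class 'int'>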
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,6 +20,8 @@ from typing import TYPE_CHECKING, Any, Mapping, Optional +from pymongo.uri_parser_shared import _parse_kms_tls_options + try: import pymongocrypt # type:ignore[import-untyped] # noqa: F401 @@ -32,9 +34,9 @@ from bson import int64 from pymongo.common import validate_is_mapping from pymongo.errors import ConfigurationError -from pymongo.uri_parser import _parse_kms_tls_options if TYPE_CHECKING: + from pymongo.pyopenssl_context import SSLContext from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg @@ -57,6 +59,7 @@ def __init__( crypt_shared_lib_required: bool = False, bypass_query_analysis: bool = False, encrypted_fields_map: Optional[Mapping[str, Any]] = None, + key_expiration_ms: Optional[int] = None, ) -> None: """Options to configure automatic client-side field level encryption. @@ -191,9 +194,14 @@ def __init__( ] } } + :param key_expiration_ms: The cache expiration time for data encryption keys. + Defaults to ``None`` which defers to libmongocrypt's default which is currently 60000. + Set to 0 to disable key expiration. + .. versionchanged:: 4.12 + Added the `key_expiration_ms` parameter. .. versionchanged:: 4.2 - Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, + Added the `encrypted_fields_map`, `crypt_shared_lib_path`, `crypt_shared_lib_required`, and `bypass_query_analysis` parameters. .. versionchanged:: 4.0 @@ -210,7 +218,6 @@ def __init__( if encrypted_fields_map: validate_is_mapping("encrypted_fields_map", encrypted_fields_map) self._encrypted_fields_map = encrypted_fields_map - self._bypass_query_analysis = bypass_query_analysis self._crypt_shared_lib_path = crypt_shared_lib_path self._crypt_shared_lib_required = crypt_shared_lib_required self._kms_providers = kms_providers @@ -225,12 +232,27 @@ def __init__( mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] self._mongocryptd_spawn_args = mongocryptd_spawn_args if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") + raise TypeError( + f"mongocryptd_spawn_args must be a list, not {type(self._mongocryptd_spawn_args)}" + ) if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") # Maps KMS provider name to a SSLContext. 
- self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) + self._kms_tls_options = kms_tls_options + self._sync_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None + self._async_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None self._bypass_query_analysis = bypass_query_analysis + self._key_expiration_ms = key_expiration_ms + + def _kms_ssl_contexts(self, is_sync: bool) -> dict[str, SSLContext]: + if is_sync: + if self._sync_kms_ssl_contexts is None: + self._sync_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, True) + return self._sync_kms_ssl_contexts + else: + if self._async_kms_ssl_contexts is None: + self._async_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, False) + return self._async_kms_ssl_contexts class RangeOpts: diff --git a/pymongo/errors.py b/pymongo/errors.py index 2cd1081e3b..794b5a9398 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 86b53c6376..80acaa10c0 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/hello.py b/pymongo/hello.py index c30b825e19..1eb40ed929 100644 --- a/pymongo/hello.py +++ b/pymongo/hello.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py index 83ea2ddf78..a664e87a69 100644 --- a/pymongo/helpers_shared.py +++ b/pymongo/helpers_shared.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
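The `AutoEncryptionOpts` change above stops building KMS TLS contexts eagerly in `__init__` and instead caches one context map per flavor (sync vs. async) on first use. A generic sketch of that shape (names hypothetical, with a string standing in for a real `SSLContext`):

from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class _LazyPerFlavor(Generic[T]):
    """Cache one lazily-built value per (sync, async) flavor."""

    def __init__(self, factory: Callable[[bool], T]) -> None:
        self._factory = factory
        self._sync: Optional[T] = None
        self._async: Optional[T] = None

    def get(self, is_sync: bool) -> T:
        if is_sync:
            if self._sync is None:
                self._sync = self._factory(True)
            return self._sync
        if self._async is None:
            self._async = self._factory(False)
        return self._async


contexts = _LazyPerFlavor(lambda is_sync: f"ssl-context(is_sync={is_sync})")
print(contexts.get(True), contexts.get(False))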
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -122,7 +122,7 @@ def _index_list( """ if direction is not None: if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") + raise TypeError(f"Expected a string and a direction, not {type(key_or_list)}") return [(key_or_list, direction)] else: if isinstance(key_or_list, str): @@ -132,7 +132,9 @@ def _index_list( elif isinstance(key_or_list, abc.Mapping): return list(key_or_list.items()) elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") + raise TypeError( + f"if no direction is specified, key_or_list must be an instance of list, not {type(key_or_list)}" + ) values: list[tuple[str, int]] = [] for item in key_or_list: if isinstance(item, str): @@ -172,11 +174,12 @@ def _index_document(index_list: _IndexList) -> dict[str, Any]: def _validate_index_key_pair(key: Any, value: Any) -> None: if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") + raise TypeError(f"first item in each key pair must be an instance of str, not {type(key)}") if not isinstance(value, (str, int, abc.Mapping)): raise TypeError( "second item in each key pair must be 1, -1, " "'2d', or another valid MongoDB index specifier." + f", not {type(value)}" ) diff --git a/pymongo/lock.py b/pymongo/lock.py index 6bf7138017..ad990fce3f 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/logger.py b/pymongo/logger.py index 2ff35328b4..1b3fe43b86 100644 --- a/pymongo/logger.py +++ b/pymongo/logger.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -96,6 +96,14 @@ class _SDAMStatusMessage(str, enum.Enum): } +def _log_client_error() -> None: + # This is called from a daemon thread so check for None to account for interpreter shutdown. + logger = _CLIENT_LOGGER + if logger: + # logger.exception includes the full traceback. + logger.exception("MongoClient background task encountered an error:") + + def _debug_log(logger: logging.Logger, **fields: Any) -> None: logger.debug(LogMessage(**fields)) diff --git a/pymongo/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py index 89bfa65281..5f1e404720 100644 --- a/pymongo/max_staleness_selectors.py +++ b/pymongo/max_staleness_selectors.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/message.py b/pymongo/message.py index 10c9edb5cd..d51c77a174 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -105,7 +105,7 @@ "insert": "documents", "update": "updates", "delete": "deletes", - "bulkWrite": "bulkWrite", + "bulkWrite": "ops", } _UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index a815cbc8a9..778abe27ef 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 96f88597d2..101a8fbc37 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -472,14 +472,15 @@ def _validate_event_listeners( ) -> Sequence[_EventListeners]: """Validate event listeners""" if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") + raise TypeError(f"{option} must be a list or tuple, not {type(listeners)}") for listener in listeners: if not isinstance(listener, _EventListener): raise TypeError( f"Listeners for {option} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." + "ConnectionPoolListener," + f"not {type(listener)}" ) return listeners @@ -496,7 +497,8 @@ def register(listener: _EventListener) -> None: f"Listeners for {listener} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." + "ConnectionPoolListener," + f"not {type(listener)}" ) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 11c66bf16e..3fa180bf7a 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
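The `monitoring.register()` validation above guards the public listener API. A minimal command listener using only documented PyMongo classes:

from pymongo import monitoring


class CommandLogger(monitoring.CommandListener):
    def started(self, event: monitoring.CommandStartedEvent) -> None:
        print(f"{event.command_name} started on {event.connection_id}")

    def succeeded(self, event: monitoring.CommandSucceededEvent) -> None:
        print(f"{event.command_name} took {event.duration_micros} us")

    def failed(self, event: monitoring.CommandFailedEvent) -> None:
        print(f"{event.command_name} failed: {event.failure}")


# register() raises the TypeError above for anything that is not one of
# the five listener base classes, now reporting the received type.
monitoring.register(CommandLogger())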
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,21 +16,26 @@ from __future__ import annotations import asyncio +import collections import errno import socket import struct import sys import time -from asyncio import AbstractEventLoop, Future +from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport from typing import ( TYPE_CHECKING, + Any, Optional, Union, ) from pymongo import _csot, ssl_support from pymongo._asyncio_task import create_task -from pymongo.errors import _OperationCancelled +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import decompress +from pymongo.errors import ProtocolError, _OperationCancelled +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.socket_checker import _errno_from_exception try: @@ -41,22 +46,18 @@ _HAVE_SSL = False try: - from pymongo.pyopenssl_context import ( - BLOCKING_IO_LOOKUP_ERROR, - BLOCKING_IO_READ_ERROR, - BLOCKING_IO_WRITE_ERROR, - _sslConn, - ) + from pymongo.pyopenssl_context import _sslConn _HAVE_PYOPENSSL = True except ImportError: _HAVE_PYOPENSSL = False - _sslConn = SSLSocket # type: ignore - from pymongo.ssl_support import ( # type: ignore[assignment] - BLOCKING_IO_LOOKUP_ERROR, - BLOCKING_IO_READ_ERROR, - BLOCKING_IO_WRITE_ERROR, - ) + _sslConn = SSLSocket # type: ignore[assignment, misc] + +from pymongo.ssl_support import ( + BLOCKING_IO_LOOKUP_ERROR, + BLOCKING_IO_READ_ERROR, + BLOCKING_IO_WRITE_ERROR, +) if TYPE_CHECKING: from pymongo.asynchronous.pool import AsyncConnection @@ -66,16 +67,18 @@ _UNPACK_COMPRESSION_HEADER = struct.Struct(" None: +# These socket-based I/O methods are for KMS requests and any other network operations that do not use +# the MongoDB wire protocol +async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: timeout = sock.gettimeout() sock.settimeout(0.0) loop = asyncio.get_running_loop() try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout) + await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout) else: await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] except asyncio.TimeoutError as exc: @@ -87,7 +90,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non if sys.platform != "win32": - async def _async_sendall_ssl( + async def _async_socket_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop ) -> None: view = memoryview(buf) @@ -130,7 +133,7 @@ def _is_ready(fut: Future) -> None: loop.remove_reader(fd) loop.remove_writer(fd) - async def _async_receive_ssl( + async def _async_socket_receive_ssl( conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) @@ -184,7 +187,7 @@ def _is_ready(fut: Future) -> None: # The default Windows asyncio event loop does not support loop.add_reader/add_writer: # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. 
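Because the default Windows event loop lacks `loop.add_writer()`, the `_async_socket_sendall_ssl` fallback above polls a non-blocking SSL socket, sleeping with capped exponential backoff between attempts. The same loop in isolation (a sketch; assumes `sock` was already put in non-blocking mode, and the 0.512s cap matches the code above):

import asyncio
import ssl


async def sendall_backoff(sock: ssl.SSLSocket, buf: bytes) -> None:
    # For loops without add_writer (Windows Proactor): poll the socket,
    # sleeping with capped exponential backoff between retries.
    view = memoryview(buf)
    total_sent = 0
    backoff = 0.001
    while total_sent < len(view):
        try:
            sent = sock.send(view[total_sent:])
        except (ssl.SSLWantWriteError, ssl.SSLWantReadError):
            sent = 0
        if sent > 0:
            backoff = 0.001  # made progress: reset the backoff
            total_sent += sent
        else:
            await asyncio.sleep(backoff)  # yield to the event loop
            backoff = min(backoff * 2, 0.512)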
- async def _async_sendall_ssl( + async def _async_socket_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop ) -> None: view = memoryview(buf) @@ -205,7 +208,7 @@ async def _async_sendall_ssl( backoff = min(backoff * 2, 0.512) total_sent += sent - async def _async_receive_ssl( + async def _async_socket_receive_ssl( conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) @@ -244,52 +247,6 @@ async def _poll_cancellation(conn: AsyncConnection) -> None: await asyncio.sleep(_POLL_TIMEOUT) -async def async_receive_data( - conn: AsyncConnection, length: int, deadline: Optional[float] -) -> memoryview: - sock = conn.conn - sock_timeout = sock.gettimeout() - timeout: Optional[Union[float, int]] - if deadline: - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - timeout = max(deadline - time.monotonic(), 0) - else: - timeout = sock_timeout - - sock.settimeout(0.0) - loop = asyncio.get_running_loop() - cancellation_task = create_task(_poll_cancellation(conn)) - try: - if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] - else: - read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] - tasks = [read_task, cancellation_task] - try: - done, pending = await asyncio.wait( - tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if pending: - await asyncio.wait(pending) - if len(done) == 0: - raise socket.timeout("timed out") - if read_task in done: - return read_task.result() - raise _OperationCancelled("operation cancelled") - except asyncio.CancelledError: - for task in tasks: - task.cancel() - await asyncio.wait(tasks) - raise - - finally: - sock.settimeout(sock_timeout) - - async def async_receive_data_socket( sock: Union[socket.socket, _sslConn], length: int ) -> memoryview: @@ -301,18 +258,23 @@ async def async_receive_data_socket( try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): return await asyncio.wait_for( - _async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + _async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] timeout=timeout, ) else: - return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + return await asyncio.wait_for( + _async_socket_receive(sock, length, loop), # type: ignore[arg-type] + timeout=timeout, + ) except asyncio.TimeoutError as err: raise socket.timeout("timed out") from err finally: sock.settimeout(sock_timeout) -async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: +async def _async_socket_receive( + conn: socket.socket, length: int, loop: AbstractEventLoop +) -> memoryview: mv = memoryview(bytearray(length)) bytes_read = 0 while bytes_read < length: @@ -328,7 +290,7 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn + sock = conn.conn.sock timed_out = False # Check if the connection's socket has been manually closed if sock.fileno() == -1: @@ -413,3 +375,403 @@ def receive_data(conn: Connection, length: 
int, deadline: Optional[float]) -> me conn.set_conn_timeout(orig_timeout) return mv + + +class NetworkingInterfaceBase: + def __init__(self, conn: Any): + self.conn = conn + + @property + def gettimeout(self) -> Any: + raise NotImplementedError + + def settimeout(self, timeout: float | None) -> None: + raise NotImplementedError + + def close(self) -> Any: + raise NotImplementedError + + def is_closing(self) -> bool: + raise NotImplementedError + + @property + def get_conn(self) -> Any: + raise NotImplementedError + + @property + def sock(self) -> Any: + raise NotImplementedError + + +class AsyncNetworkingInterface(NetworkingInterfaceBase): + def __init__(self, conn: tuple[Transport, PyMongoProtocol]): + super().__init__(conn) + + @property + def gettimeout(self) -> float | None: + return self.conn[1].gettimeout + + def settimeout(self, timeout: float | None) -> None: + self.conn[1].settimeout(timeout) + + async def close(self) -> None: + self.conn[1].close() + await self.conn[1].wait_closed() + + def is_closing(self) -> bool: + return self.conn[0].is_closing() + + @property + def get_conn(self) -> PyMongoProtocol: + return self.conn[1] + + @property + def sock(self) -> socket.socket: + return self.conn[0].get_extra_info("socket") + + +class NetworkingInterface(NetworkingInterfaceBase): + def __init__(self, conn: Union[socket.socket, _sslConn]): + super().__init__(conn) + + def gettimeout(self) -> float | None: + return self.conn.gettimeout() + + def settimeout(self, timeout: float | None) -> None: + self.conn.settimeout(timeout) + + def close(self) -> None: + self.conn.close() + + def is_closing(self) -> bool: + return self.conn.is_closing() + + @property + def get_conn(self) -> Union[socket.socket, _sslConn]: + return self.conn + + @property + def sock(self) -> Union[socket.socket, _sslConn]: + return self.conn + + def fileno(self) -> int: + return self.conn.fileno() + + def recv_into(self, buffer: bytes) -> int: + return self.conn.recv_into(buffer) + + +class PyMongoProtocol(BufferedProtocol): + def __init__(self, timeout: Optional[float] = None): + self.transport: Transport = None # type: ignore[assignment] + # Each message is reader in 2-3 parts: header, compression header, and message body + # The message buffer is allocated after the header is read. + self._header = memoryview(bytearray(16)) + self._header_index = 0 + self._compression_header = memoryview(bytearray(9)) + self._compression_index = 0 + self._message: Optional[memoryview] = None + self._message_index = 0 + # State. TODO: replace booleans with an enum? + self._expecting_header = True + self._expecting_compression = False + self._message_size = 0 + self._op_code = 0 + self._connection_lost = False + self._read_waiter: Optional[Future] = None + self._timeout = timeout + self._is_compressed = False + self._compressor_id: Optional[int] = None + self._max_message_size = MAX_MESSAGE_SIZE + self._response_to: Optional[int] = None + self._closed = asyncio.get_running_loop().create_future() + self._pending_messages: collections.deque[Future] = collections.deque() + self._done_messages: collections.deque[Future] = collections.deque() + + def settimeout(self, timeout: float | None) -> None: + self._timeout = timeout + + @property + def gettimeout(self) -> float | None: + """The configured timeout for the socket that underlies our protocol pair.""" + return self._timeout + + def connection_made(self, transport: BaseTransport) -> None: + """Called exactly once when a connection is made. 
+ The transport argument is the transport representing the write side of the connection. + """ + self.transport = transport # type: ignore[assignment] + self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE) + + async def write(self, message: bytes) -> None: + """Write a message to this connection's transport.""" + if self.transport.is_closing(): + raise OSError("Connection is closed") + self.transport.write(message) + self.transport.resume_reading() + + async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]: + """Read a single MongoDB Wire Protocol message from this connection.""" + if self.transport: + try: + self.transport.resume_reading() + # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322 + except AttributeError: + raise OSError("connection is already closed") from None + self._max_message_size = max_message_size + if self._done_messages: + message = await self._done_messages.popleft() + else: + if self.transport and self.transport.is_closing(): + raise OSError("connection is already closed") + read_waiter = asyncio.get_running_loop().create_future() + self._pending_messages.append(read_waiter) + try: + message = await read_waiter + finally: + if read_waiter in self._done_messages: + self._done_messages.remove(read_waiter) + if message: + op_code, compressor_id, response_to, data = message + # No request_id for exhaust cursor "getMore". + if request_id is not None: + if request_id != response_to: + raise ProtocolError( + f"Got response id {response_to!r} but expected {request_id!r}" + ) + if compressor_id is not None: + data = decompress(data, compressor_id) + return data, op_code + raise OSError("connection closed") + + def get_buffer(self, sizehint: int) -> memoryview: + """Called to allocate a new receive buffer. + The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data. + If any data does not fit into the returned buffer, this method will be called again until + either no data remains or an empty buffer is returned. + """ + # Due to a bug, Python <=3.11 will call get_buffer() even after we raise + # ProtocolError in buffer_updated() and call connection_lost(). We allocate + # a temp buffer to drain the waiting data. + if self._connection_lost: + if not self._message: + self._message = memoryview(bytearray(2**14)) + return self._message + # TODO: optimize this by caching pointers to the buffers. 
+ # return self._buffer[self._index:] + if self._expecting_header: + return self._header[self._header_index :] + if self._expecting_compression: + return self._compression_header[self._compression_index :] + return self._message[self._message_index :] # type: ignore[index] + + def buffer_updated(self, nbytes: int) -> None: + """Called when the buffer was updated with the received data""" + # Wrote 0 bytes into a non-empty buffer, signal connection closed + if nbytes == 0: + self.close(OSError("connection closed")) + return + if self._connection_lost: + return + if self._expecting_header: + self._header_index += nbytes + if self._header_index >= 16: + self._expecting_header = False + try: + ( + self._message_size, + self._op_code, + self._response_to, + self._expecting_compression, + ) = self.process_header() + except ProtocolError as exc: + self.close(exc) + return + self._message = memoryview(bytearray(self._message_size)) + return + if self._expecting_compression: + self._compression_index += nbytes + if self._compression_index >= 9: + self._expecting_compression = False + self._op_code, self._compressor_id = self.process_compression_header() + return + + self._message_index += nbytes + if self._message_index >= self._message_size: + self._expecting_header = True + # Pause reading to avoid storing an arbitrary number of messages in memory. + self.transport.pause_reading() + if self._pending_messages: + result = self._pending_messages.popleft() + else: + result = asyncio.get_running_loop().create_future() + # Future has been cancelled, close this connection + if result.done(): + self.close(None) + return + # Necessary values to reconstruct and verify message + result.set_result( + (self._op_code, self._compressor_id, self._response_to, self._message) + ) + self._done_messages.append(result) + # Reset internal state to expect a new message + self._header_index = 0 + self._compression_index = 0 + self._message_index = 0 + self._message_size = 0 + self._message = None + self._op_code = 0 + self._compressor_id = None + self._response_to = None + + def process_header(self) -> tuple[int, int, int, bool]: + """Unpack a MongoDB Wire Protocol header.""" + length, _, response_to, op_code = _UNPACK_HEADER(self._header) + expecting_compression = False + if op_code == 2012: # OP_COMPRESSED + if length <= 25: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)" + ) + expecting_compression = True + length -= 9 + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > self._max_message_size: + raise ProtocolError( + f"Message length ({length!r}) is larger than server max " + f"message size ({self._max_message_size!r})" + ) + + return length - 16, op_code, response_to, expecting_compression + + def process_compression_header(self) -> tuple[int, int]: + """Unpack a MongoDB Wire Protocol compression header.""" + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header) + return op_code, compressor_id + + def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None: + pending = list(self._pending_messages) + for msg in pending: + if not msg.done(): + if exc is None: + msg.set_result(None) + else: + msg.set_exception(exc) + self._done_messages.append(msg) + + def close(self, exc: Optional[Exception] = None) -> None: + self.transport.abort() + self._resolve_pending_messages(exc) + self._connection_lost = True + + def 
connection_lost(self, exc: Optional[Exception] = None) -> None: + self._resolve_pending_messages(exc) + if not self._closed.done(): + self._closed.set_result(None) + + async def wait_closed(self) -> None: + await self._closed + + +async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None: + try: + await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout) + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc + + +async def async_receive_message( + conn: AsyncConnection, + request_id: Optional[int], + max_message_size: int = MAX_MESSAGE_SIZE, +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + timeout: Optional[Union[float, int]] + timeout = conn.conn.gettimeout + if _csot.get_timeout(): + deadline = _csot.get_deadline() + else: + if timeout: + deadline = time.monotonic() + timeout + else: + deadline = None + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + timeout = max(deadline - time.monotonic(), 0) + + cancellation_task = create_task(_poll_cancellation(conn)) + read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size)) + tasks = [read_task, cancellation_task] + try: + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if pending: + await asyncio.wait(pending) + if len(done) == 0: + raise socket.timeout("timed out") + if read_task in done: + data, op_code = read_task.result() + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) + raise _OperationCancelled("operation cancelled") + except asyncio.CancelledError: + for task in tasks: + task.cancel() + await asyncio.wait(tasks) + raise + + +def receive_message( + conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + if _csot.get_timeout(): + deadline = _csot.get_deadline() + else: + timeout = conn.conn.gettimeout() + if timeout: + deadline = time.monotonic() + timeout + else: + deadline = None + # Ignore the response's request id. + length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) + # No request_id for exhaust cursor "getMore". 
+ if request_id is not None: + if request_id != response_to: + raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > max_message_size: + raise ProtocolError( + f"Message length ({length!r}) is larger than server max " + f"message size ({max_message_size!r})" + ) + if op_code == 2012: + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) + data = decompress(receive_data(conn, length - 25, deadline), compressor_id) + else: + data = receive_data(conn, length - 16, deadline) + + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) diff --git a/pymongo/ocsp_cache.py b/pymongo/ocsp_cache.py index 3facefe350..2df232848f 100644 --- a/pymongo/ocsp_cache.py +++ b/pymongo/ocsp_cache.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/ocsp_support.py b/pymongo/ocsp_support.py index ee359b71c2..8322f821fb 100644 --- a/pymongo/ocsp_support.py +++ b/pymongo/ocsp_support.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/operations.py b/pymongo/operations.py index 482ab68003..300f1ba123 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e3..323debdce2 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None: callback; see monitor.py. """ self._stopped = True + if self._task is not None: + self._task.cancel() async def join(self, timeout: Optional[int] = None) -> None: if self._task is not None: @@ -98,6 +100,7 @@ async def _run(self) -> None: if not await self._target(): self._stopped = True break + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
except BaseException: self._stopped = True raise @@ -230,6 +233,7 @@ def _run(self) -> None: if not self._target(): self._stopped = True break + # Catch KeyboardInterrupt, etc. and cleanup. except BaseException: with self._lock: self._stopped = True diff --git a/pymongo/pool.py b/pymongo/pool.py index fbbb70fc68..456ff3df0a 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py index 038dbb3b5d..a2e309cc56 100644 --- a/pymongo/pool_options.py +++ b/pymongo/pool_options.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py new file mode 100644 index 0000000000..308ecef349 --- /dev/null +++ b/pymongo/pool_shared.py @@ -0,0 +1,539 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pool utilities and shared helper methods.""" +from __future__ import annotations + +import asyncio +import functools +import socket +import ssl +import sys +from typing import ( + TYPE_CHECKING, + Any, + NoReturn, + Optional, + Union, +) + +from pymongo import _csot +from pymongo.asynchronous.helpers import _getaddrinfo +from pymongo.errors import ( # type:ignore[attr-defined] + AutoReconnect, + ConnectionFailure, + NetworkTimeout, + _CertificateError, +) +from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol +from pymongo.pool_options import PoolOptions +from pymongo.ssl_support import PYSSLError, SSLError, _has_sni + +SSLErrors = (PYSSLError, SSLError) +if TYPE_CHECKING: + from pymongo.pyopenssl_context import _sslConn + from pymongo.typings import _Address + +try: + from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl + + def _set_non_inheritable_non_atomic(fd: int) -> None: + """Set the close-on-exec flag on the given file descriptor.""" + flags = fcntl(fd, F_GETFD) + fcntl(fd, F_SETFD, flags | FD_CLOEXEC) + +except ImportError: + # Windows, various platforms we don't claim to support + # (Jython, IronPython, ..), systems that don't provide + # everything we need from fcntl, etc. 
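# Sketch (not part of this patch): what the fcntl-based
# _set_non_inheritable_non_atomic() above accomplishes. It is "non-atomic"
# because the flag is read and then written back in two separate syscalls.
# Assumes a POSIX platform where fcntl is available; the helper name below is
# hypothetical.

import socket
from fcntl import F_GETFD, FD_CLOEXEC, fcntl

def _has_cloexec(fd: int) -> bool:
    # True when close-on-exec is set, i.e. the fd won't leak into exec'd children.
    return bool(fcntl(fd, F_GETFD) & FD_CLOEXEC)

sock = socket.socket()
print(_has_cloexec(sock.fileno()))  # True on CPython >= 3.4 (PEP 446)
sock.close()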
+ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 + """Dummy function for platforms that don't provide fcntl.""" + + +_MAX_TCP_KEEPIDLE = 120 +_MAX_TCP_KEEPINTVL = 10 +_MAX_TCP_KEEPCNT = 9 + +if sys.platform == "win32": + try: + import _winreg as winreg + except ImportError: + import winreg + + def _query(key, name, default): + try: + value, _ = winreg.QueryValueEx(key, name) + # Ensure the value is a number or raise ValueError. + return int(value) + except (OSError, ValueError): + # QueryValueEx raises OSError when the key does not exist (i.e. + # the system is using the Windows default value). + return default + + try: + with winreg.OpenKey( + winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" + ) as key: + _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) + _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) + except OSError: + # We could not check the default values because winreg.OpenKey failed. + # Assume the system is using the default values. + _WINDOWS_TCP_IDLE_MS = 7200000 + _WINDOWS_TCP_INTERVAL_MS = 1000 + + def _set_keepalive_times(sock): + idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) + interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) + if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: + sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) + +else: + + def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: + if hasattr(socket, tcp_option): + sockopt = getattr(socket, tcp_option) + try: + # PYTHON-1350 - NetBSD doesn't implement getsockopt for + # TCP_KEEPIDLE and friends. Don't attempt to set the + # values there. + default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) + if default > max_value: + sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) + except OSError: + pass + + def _set_keepalive_times(sock: socket.socket) -> None: + _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) + _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) + _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) + + +def _raise_connection_failure( + address: Any, + error: Exception, + msg_prefix: Optional[str] = None, + timeout_details: Optional[dict[str, float]] = None, +) -> NoReturn: + """Convert a socket.error to ConnectionFailure and raise it.""" + host, port = address + # If connecting to a Unix socket, port will be None. + if port is not None: + msg = "%s:%d: %s" % (host, port, error) + else: + msg = f"{host}: {error}" + if msg_prefix: + msg = msg_prefix + msg + if "configured timeouts" not in msg: + msg += format_timeout_details(timeout_details) + if isinstance(error, socket.timeout): + raise NetworkTimeout(msg) from error + elif isinstance(error, SSLErrors) and "timed out" in str(error): + # Eventlet does not distinguish TLS network timeouts from other + # SSLErrors (https://github.com/eventlet/eventlet/issues/692). + # Luckily, we can work around this limitation because the phrase + # 'timed out' appears in all the timeout related SSLErrors raised. 
+ raise NetworkTimeout(msg) from error + else: + raise AutoReconnect(msg) from error + + +def _get_timeout_details(options: PoolOptions) -> dict[str, float]: + details = {} + timeout = _csot.get_timeout() + socket_timeout = options.socket_timeout + connect_timeout = options.connect_timeout + if timeout: + details["timeoutMS"] = timeout * 1000 + if socket_timeout and not timeout: + details["socketTimeoutMS"] = socket_timeout * 1000 + if connect_timeout: + details["connectTimeoutMS"] = connect_timeout * 1000 + return details + + +def format_timeout_details(details: Optional[dict[str, float]]) -> str: + result = "" + if details: + result += " (configured timeouts:" + for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: + if timeout in details: + result += f" {timeout}: {details[timeout]}ms," + result = result[:-1] + result += ")" + return result + + +class _CancellationContext: + def __init__(self) -> None: + self._cancelled = False + + def cancel(self) -> None: + """Cancel this context.""" + self._cancelled = True + + @property + def cancelled(self) -> bool: + """Was cancel called?""" + return self._cancelled + + +async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket: + """Given (host, port) and PoolOptions, connect and return a raw socket object. + + Can raise socket.error. + + This is a modified version of create_connection from CPython >= 2.7. + """ + host, port = address + + # Check if dealing with a unix domain socket + if host.endswith(".sock"): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported on this system") + sock = socket.socket(socket.AF_UNIX) + # SOCK_CLOEXEC not supported for Unix sockets. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.connect(host) + return sock + except OSError: + sock.close() + raise + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != "localhost": + family = socket.AF_UNSPEC + + err = None + for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): + af, socktype, proto, dummy, sa = res + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 + # all file descriptors are created non-inheritable. See PEP 446. + try: + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) + except OSError: + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support + # it? + sock = socket.socket(af, socktype, proto) + # Fallback when SOCK_CLOEXEC isn't available. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # CSOT: apply timeout to socket connect. + timeout = _csot.remaining() + if timeout is None: + timeout = options.connect_timeout + elif timeout <= 0: + raise socket.timeout("timed out") + sock.settimeout(timeout) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + _set_keepalive_times(sock) + sock.connect(sa) + return sock + except OSError as e: + err = e + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. 
+ raise OSError("getaddrinfo failed") + + +async def _async_configured_socket( + address: _Address, options: PoolOptions +) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if _has_sni(False): + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor( + None, + functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore] + ) + else: + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +async def _configured_protocol_interface( + address: _Address, options: PoolOptions +) -> AsyncNetworkingInterface: + """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets protocol's SSL and timeout options. + """ + sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + timeout = options.socket_timeout + + if ssl_context is None: + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock + ) + ) + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] + lambda: PyMongoProtocol(timeout=timeout), + sock=sock, + server_hostname=host, + ssl=ssl_context, + ) + except _CertificateError: + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. 
+ details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] + except _CertificateError: + transport.abort() + raise + + return AsyncNetworkingInterface((transport, protocol)) + + +def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: + """Given (host, port) and PoolOptions, connect and return a raw socket object. + + Can raise socket.error. + + This is a modified version of create_connection from CPython >= 2.7. + """ + host, port = address + + # Check if dealing with a unix domain socket + if host.endswith(".sock"): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported on this system") + sock = socket.socket(socket.AF_UNIX) + # SOCK_CLOEXEC not supported for Unix sockets. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.connect(host) + return sock + except OSError: + sock.close() + raise + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != "localhost": + family = socket.AF_UNSPEC + + err = None + for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined, unused-ignore] + af, socktype, proto, dummy, sa = res + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 + # all file descriptors are created non-inheritable. See PEP 446. + try: + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) + except OSError: + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support + # it? + sock = socket.socket(af, socktype, proto) + # Fallback when SOCK_CLOEXEC isn't available. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # CSOT: apply timeout to socket connect. + timeout = _csot.remaining() + if timeout is None: + timeout = options.connect_timeout + elif timeout <= 0: + raise socket.timeout("timed out") + sock.settimeout(timeout) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + _set_keepalive_times(sock) + sock.connect(sa) + return sock + except OSError as e: + err = e + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. + raise OSError("getaddrinfo failed") + + +def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. 
+ if _has_sni(True): + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore] + else: + ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface: + """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return NetworkingInterface(sock) + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if _has_sni(True): + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) + else: + ssl_sock = ssl_context.wrap_socket(sock) + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return NetworkingInterface(ssl_sock) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 8c643394b2..0d4f27cf55 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,10 +14,11 @@ """A CPython compatible SSLContext implementation wrapping PyOpenSSL's context. 
+ +Due to limitations of the CPython asyncio.Protocol implementation for SSL, the async API does not support PyOpenSSL. """ from __future__ import annotations -import asyncio import socket as _socket import ssl as _stdlibssl import sys as _sys @@ -109,15 +110,12 @@ def __init__( ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool, - is_async: bool = False, ): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super().__init__(ctx, sock) - self._is_async = is_async def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: - is_async = kwargs.pop("allow_async", True) and self._is_async timeout = self.gettimeout() if timeout: start = _time.monotonic() @@ -126,7 +124,7 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: return call(*args, **kwargs) except BLOCKING_IO_ERRORS as exc: # Do not retry if the connection is in non-blocking mode. - if is_async or timeout == 0: + if timeout == 0: raise exc # Check for closed socket. if self.fileno() == -1: @@ -148,7 +146,6 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: continue def do_handshake(self, *args: Any, **kwargs: Any) -> None: - kwargs["allow_async"] = False return self._call(super().do_handshake, *args, **kwargs) def recv(self, *args: Any, **kwargs: Any) -> bytes: @@ -379,58 +376,6 @@ def set_default_verify_paths(self) -> None: # but not that same as CPython's. self._ctx.set_default_verify_paths() - async def a_wrap_socket( - self, - sock: _socket.socket, - server_side: bool = False, - do_handshake_on_connect: bool = True, - suppress_ragged_eofs: bool = True, - server_hostname: Optional[str] = None, - session: Optional[_SSL.Session] = None, - ) -> _sslConn: - """Wrap an existing Python socket connection and return a TLS socket - object. - """ - ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True) - loop = asyncio.get_running_loop() - if session: - ssl_conn.set_session(session) - if server_side is True: - ssl_conn.set_accept_state() - else: - # SNI - if server_hostname and not _is_ip_address(server_hostname): - # XXX: Do this in a callback registered with - # SSLContext.set_info_callback? See Twisted for an example. - ssl_conn.set_tlsext_host_name(server_hostname.encode("idna")) - if self.verify_mode != _stdlibssl.CERT_NONE: - # Request a stapled OCSP response. - await loop.run_in_executor(None, ssl_conn.request_ocsp) - ssl_conn.set_connect_state() - # If this wasn't true the caller of wrap_socket would call - # do_handshake() - if do_handshake_on_connect: - # XXX: If we do hostname checking in a callback we can get rid - # of this call to do_handshake() since the handshake - # will happen automatically later. - await loop.run_in_executor(None, ssl_conn.do_handshake) - # XXX: Do this in a callback registered with - # SSLContext.set_info_callback? See Twisted for an example. 
- if self.check_hostname and server_hostname is not None: - from service_identity import pyopenssl - - try: - if _is_ip_address(server_hostname): - pyopenssl.verify_ip_address(ssl_conn, server_hostname) - else: - pyopenssl.verify_hostname(ssl_conn, server_hostname) - except ( # type:ignore[misc] - service_identity.SICertificateError, - service_identity.SIVerificationError, - ) as exc: - raise _CertificateError(str(exc)) from None - return ssl_conn - def wrap_socket( self, sock: _socket.socket, diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index fa2f4a318a..2adc403366 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -38,7 +38,7 @@ def __init__(self, level: Optional[str] = None) -> None: if level is None or isinstance(level, str): self.__level = level else: - raise TypeError("level must be a string or None.") + raise TypeError(f"level must be a string or None, not {type(level)}") @property def level(self) -> Optional[str]: diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index 8c6e6de45d..dae414c37c 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,6 +19,7 @@ from __future__ import annotations +import warnings from collections import abc from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence @@ -103,6 +104,11 @@ def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: if not isinstance(hedge, dict): raise TypeError(f"hedge must be a dictionary, not {hedge!r}") + warnings.warn( + "The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. Support for 'hedge' will be removed in PyMongo 5.0.", + DeprecationWarning, + stacklevel=4, + ) return hedge @@ -183,7 +189,9 @@ def max_staleness(self) -> int: @property def hedge(self) -> Optional[_Hedge]: - """The read preference ``hedge`` parameter. + """**DEPRECATED** - The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. Support for 'hedge' will be removed in PyMongo 5.0. + + The read preference ``hedge`` parameter. A dictionary that configures how the server will perform hedged reads. It consists of the following keys: @@ -203,6 +211,12 @@ def hedge(self) -> Optional[_Hedge]: .. versionadded:: 3.11 """ + if self.__hedge is not None: + warnings.warn( + "The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. 
Support for 'hedge' will be removed in PyMongo 5.0.", + DeprecationWarning, + stacklevel=2, + ) return self.__hedge @property @@ -312,7 +326,7 @@ class PrimaryPreferred(_ServerMode): replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - :param hedge: The :attr:`~hedge` to use if the primary is not available. + :param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. @@ -354,7 +368,7 @@ class Secondary(_ServerMode): replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. + :param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. @@ -397,7 +411,7 @@ class SecondaryPreferred(_ServerMode): replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. + :param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. @@ -441,7 +455,7 @@ class Nearest(_ServerMode): replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. + :param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. diff --git a/pymongo/response.py b/pymongo/response.py index e47749423f..211ddf2354 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/results.py b/pymongo/results.py index d17ff1c3ea..bcce121fe7 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index 7fb546f61b..9cef22419e 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/server_api.py b/pymongo/server_api.py index 4a746008c4..40bb1aac3e 100644 --- a/pymongo/server_api.py +++ b/pymongo/server_api.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 064ad43375..afc5346bb7 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py index c22ad599ee..0d1425ab31 100644 --- a/pymongo/server_selectors.py +++ b/pymongo/server_selectors.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/server_type.py b/pymongo/server_type.py index 937855cc7a..7a6d2aaf14 100644 --- a/pymongo/server_type.py +++ b/pymongo/server_type.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/ssl_context.py b/pymongo/ssl_context.py index ee32145c02..2ff7428cab 100644 --- a/pymongo/ssl_context.py +++ b/pymongo/ssl_context.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 580d71f9b0..beafc717eb 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,16 +15,19 @@ """Support for SSL in PyMongo.""" from __future__ import annotations +import types import warnings -from typing import Optional +from typing import Any, Optional, Union from pymongo.errors import ConfigurationError HAVE_SSL = True +HAVE_PYSSL = True try: - import pymongo.pyopenssl_context as _ssl + import pymongo.pyopenssl_context as _pyssl except (ImportError, AttributeError) as exc: + HAVE_PYSSL = False if isinstance(exc, AttributeError): warnings.warn( "Failed to use the installed version of PyOpenSSL. " @@ -35,10 +38,10 @@ UserWarning, stacklevel=2, ) - try: - import pymongo.ssl_context as _ssl # type: ignore[no-redef] - except ImportError: - HAVE_SSL = False +try: + import pymongo.ssl_context as _ssl +except ImportError: + HAVE_SSL = False if HAVE_SSL: @@ -49,14 +52,29 @@ import ssl as _stdlibssl # noqa: F401 from ssl import CERT_NONE, CERT_REQUIRED - HAS_SNI = _ssl.HAS_SNI IPADDR_SAFE = True + + if HAVE_PYSSL: + PYSSLError: Any = _pyssl.SSLError + BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS + BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR) + BLOCKING_IO_WRITE_ERROR: tuple = ( + _pyssl.BLOCKING_IO_WRITE_ERROR, + _ssl.BLOCKING_IO_WRITE_ERROR, + ) + else: + PYSSLError = _ssl.SSLError + BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS + BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,) + BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,) SSLError = _ssl.SSLError - BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS - BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR - BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR + def _has_sni(is_sync: bool) -> bool: + if is_sync and HAVE_PYSSL: + return _pyssl.HAS_SNI + return _ssl.HAS_SNI + def get_ssl_context( certfile: Optional[str], passphrase: Optional[str], @@ -65,10 +83,15 @@ def get_ssl_context( allow_invalid_certificates: bool, allow_invalid_hostnames: bool, disable_ocsp_endpoint_check: bool, - ) -> _ssl.SSLContext: + is_sync: bool, + ) -> Union[_pyssl.SSLContext, _ssl.SSLContext]: # type: ignore[name-defined] """Create and return an SSLContext object.""" + if is_sync and HAVE_PYSSL: + ssl: types.ModuleType = _pyssl + else: + ssl = _ssl verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED - ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) if verify_mode != CERT_NONE: ctx.check_hostname = not allow_invalid_hostnames else: @@ -80,22 +103,20 @@ def get_ssl_context( # up to date versions of MongoDB 2.4 and above already disable # SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7 # and >= 3.3.4 and SSLv3 in >= 3.4.3. 
- ctx.options |= _ssl.OP_NO_SSLv2 - ctx.options |= _ssl.OP_NO_SSLv3 - ctx.options |= _ssl.OP_NO_COMPRESSION - ctx.options |= _ssl.OP_NO_RENEGOTIATION + ctx.options |= ssl.OP_NO_SSLv2 + ctx.options |= ssl.OP_NO_SSLv3 + ctx.options |= ssl.OP_NO_COMPRESSION + ctx.options |= ssl.OP_NO_RENEGOTIATION if certfile is not None: try: ctx.load_cert_chain(certfile, None, passphrase) - except _ssl.SSLError as exc: + except ssl.SSLError as exc: raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None if crlfile is not None: - if _ssl.IS_PYOPENSSL: + if ssl.IS_PYOPENSSL: raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL") # Match the server's behavior. - ctx.verify_flags = getattr( # type:ignore[attr-defined] - _ssl, "VERIFY_CRL_CHECK_LEAF", 0 - ) + ctx.verify_flags = getattr(ssl, "VERIFY_CRL_CHECK_LEAF", 0) ctx.load_verify_locations(crlfile) if ca_certs is not None: ctx.load_verify_locations(ca_certs) @@ -109,10 +130,12 @@ def get_ssl_context( class SSLError(Exception): # type: ignore pass - HAS_SNI = False IPADDR_SAFE = False - BLOCKING_IO_ERRORS = () # type:ignore[assignment] + BLOCKING_IO_ERRORS = () + + def _has_sni(is_sync: bool) -> bool: # noqa: ARG001 + return False def get_ssl_context(*dummy): # type: ignore """No ssl module, raise ConfigurationError.""" - raise ConfigurationError("The ssl module is not available.") + raise ConfigurationError("The ssl module is not available") diff --git a/pymongo/synchronous/aggregation.py b/pymongo/synchronous/aggregation.py index 7c7e6252f7..3eb0c8bf54 100644 --- a/pymongo/synchronous/aggregation.py +++ b/pymongo/synchronous/aggregation.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 56860eff3b..650e25234d 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -158,7 +158,7 @@ def _password_digest(username: str, password: str) -> str: if len(password) == 0: raise ValueError("password can't be empty") if not isinstance(username, str): - raise TypeError("username must be an instance of str") + raise TypeError(f"username must be an instance of str, not {type(username)}") md5hash = hashlib.md5() # noqa: S324 data = f"{username}:mongo:{password}" diff --git a/pymongo/synchronous/auth_aws.py b/pymongo/synchronous/auth_aws.py index 7c0d24f3a1..c7ea47886f 100644 --- a/pymongo/synchronous/auth_aws.py +++ b/pymongo/synchronous/auth_aws.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index 5a8967d96b..8a8703c142 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -213,7 +213,9 @@ def _get_access_token(self) -> Optional[str]: ) resp = cb.fetch(context) if not isinstance(resp, OIDCCallbackResult): - raise ValueError("Callback result must be of type OIDCCallbackResult") + raise ValueError( + f"Callback result must be of type OIDCCallbackResult, not {type(resp)}" + ) self.refresh_token = resp.refresh_token self.access_token = resp.access_token self.token_gen_id += 1 diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 0b709f1acf..a528b09add 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -87,7 +87,7 @@ def __init__( self, collection: Collection[_DocumentType], ordered: bool, - bypass_document_validation: bool, + bypass_document_validation: Optional[bool], comment: Optional[str] = None, let: Optional[Any] = None, ) -> None: @@ -255,8 +255,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -276,8 +276,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -302,8 +302,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -340,8 +340,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -366,8 +366,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( 
_COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -393,8 +393,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -516,8 +516,8 @@ def _execute_command( if self.comment: cmd["comment"] = self.comment _csot.apply_write_concern(cmd, write_concern) - if self.bypass_doc_val: - cmd["bypassDocumentValidation"] = True + if self.bypass_doc_val is not None: + cmd["bypassDocumentValidation"] = self.bypass_doc_val if self.let is not None and run.op_type in (_DELETE, _UPDATE): cmd["let"] = self.let if session: diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index a971ad08c0..304427b89b 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -389,7 +389,8 @@ def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 9f6e3f7cf0..d73bfb2a2b 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -241,8 +241,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -262,8 +262,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -289,8 +289,8 @@ def write_command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -330,8 +330,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=bwc.db_name, @@ -356,8 +356,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=reply, commandName=next(iter(cmd)), @@ -383,8 +383,8 @@ def unack_write( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index f1d680fc0a..aaf2d7574f 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -309,7 +309,9 @@ def __init__( ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): - raise TypeError("max_commit_time_ms must be an integer or None") + raise TypeError( + f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}" + ) @property def read_concern(self) -> Optional[ReadConcern]: @@ -455,10 +457,10 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # From the transactions spec, all the retryable writes errors plus -# WriteConcernFailed. +# WriteConcernTimeout. 
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( [ - 64, # WriteConcernFailed + 64, # WriteConcernTimeout 50, # MaxTimeMSExpired ] ) @@ -692,7 +694,8 @@ def callback(session, custom_arg, custom_kwarg=None): self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) - except Exception as exc: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as exc: if self.in_transaction: self.abort_transaction() if ( @@ -897,7 +900,9 @@ def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: another `ClientSession` instance. """ if not isinstance(cluster_time, _Mapping): - raise TypeError("cluster_time must be a subclass of collections.Mapping") + raise TypeError( + f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}" + ) if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) @@ -918,7 +923,9 @@ def advance_operation_time(self, operation_time: Timestamp) -> None: another `ClientSession` instance. """ if not isinstance(operation_time, Timestamp): - raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") + raise TypeError( + f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}" + ) self._advance_operation_time(operation_time) def _process_response(self, reply: Mapping[str, Any]) -> None: diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 6edfddc9a9..8a71768318 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -231,7 +231,7 @@ def __init__( read_concern or database.read_concern, ) if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") from pymongo.synchronous.database import Database if not isinstance(database, Database): @@ -700,7 +700,7 @@ def bulk_write( self, requests: Sequence[_WriteOp[_DocumentType]], ordered: bool = True, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, let: Optional[Mapping] = None, @@ -799,7 +799,7 @@ def _insert_one( ordered: bool, write_concern: WriteConcern, op_id: Optional[int], - bypass_doc_val: bool, + bypass_doc_val: Optional[bool], session: Optional[ClientSession], comment: Optional[Any] = None, ) -> Any: @@ -813,8 +813,8 @@ def _insert_one( def _insert_command( session: Optional[ClientSession], conn: Connection, retryable_write: bool ) -> None: - if bypass_doc_val: - command["bypassDocumentValidation"] = True + if bypass_doc_val is not None: + command["bypassDocumentValidation"] = bypass_doc_val result = conn.command( self._database.name, @@ -839,7 +839,7 @@ def _insert_command( def insert_one( self, document: Union[_DocumentType, RawBSONDocument], - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, ) -> InsertOneResult: @@ -905,7 +905,7 @@ def insert_many( self, documents: Iterable[Union[_DocumentType, RawBSONDocument]], ordered: bool = True, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, ) -> InsertManyResult: @@ -985,7 +985,7 @@ def _update( write_concern: Optional[WriteConcern] = None, op_id: Optional[int] = None, ordered: bool = True, - bypass_doc_val: Optional[bool] = False, + bypass_doc_val: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -1040,8 +1040,8 @@ def _update( if comment is not None: command["comment"] = comment # Update command. - if bypass_doc_val: - command["bypassDocumentValidation"] = True + if bypass_doc_val is not None: + command["bypassDocumentValidation"] = bypass_doc_val # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. 
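# Sketch (not part of this patch): the behavioral change made throughout this
# file by moving bypass_document_validation's default from False to None. With
# a tri-state flag, the option is sent only when the caller set it explicitly,
# and an explicit False is now forwarded instead of being dropped. The helper
# name is hypothetical.

from typing import Any, Optional

def _apply_bypass(command: dict[str, Any], bypass_doc_val: Optional[bool]) -> None:
    # None  -> option omitted entirely, the server default applies
    # True  -> "bypassDocumentValidation": true, as before
    # False -> "bypassDocumentValidation": false, previously omitted
    if bypass_doc_val is not None:
        command["bypassDocumentValidation"] = bypass_doc_val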
@@ -1081,7 +1081,7 @@ def _update_retryable( write_concern: Optional[WriteConcern] = None, op_id: Optional[int] = None, ordered: bool = True, - bypass_doc_val: Optional[bool] = False, + bypass_doc_val: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -1127,7 +1127,7 @@ def replace_one( filter: Mapping[str, Any], replacement: Mapping[str, Any], upsert: bool = False, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, session: Optional[ClientSession] = None, @@ -1236,7 +1236,7 @@ def update_one( filter: Mapping[str, Any], update: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, - bypass_document_validation: bool = False, + bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, @@ -2472,7 +2472,7 @@ def _drop_index( name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): - raise TypeError("index_or_name must be an instance of str or list") + raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}") cmd = {"dropIndexes": self._name, "index": name} cmd.update(kwargs) @@ -2941,6 +2941,7 @@ def aggregate( returning aggregate results using a cursor. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. + - `bypassDocumentValidation` (bool): If ``True``, allows the write to opt out of document-level validation. :return: A :class:`~pymongo.command_cursor.CommandCursor` over the result @@ -3071,7 +3072,7 @@ def rename( """ if not isinstance(new_name, str): - raise TypeError("new_name must be an instance of str") + raise TypeError(f"new_name must be an instance of str, not {type(new_name)}") if not new_name or ".." in new_name: raise InvalidName("collection names cannot be empty") @@ -3104,6 +3105,7 @@ def distinct( filter: Optional[Mapping[str, Any]] = None, session: Optional[ClientSession] = None, comment: Optional[Any] = None, + hint: Optional[_IndexKeyHint] = None, **kwargs: Any, ) -> list: """Get a list of distinct values for `key` among all documents @@ -3131,8 +3133,15 @@ def distinct( :class:`~pymongo.client_session.ClientSession`. :param comment: A user-provided comment to attach to this command. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to :meth:`~pymongo.collection.Collection.create_index` + (e.g. ``[('field', ASCENDING)]``). :param kwargs: See list of options above. + .. versionchanged:: 4.12 + Added ``hint`` parameter. + .. versionchanged:: 3.6 Added ``session`` parameter.
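Hypothetical usage of the new ``distinct`` ``hint`` parameter (per the ``versionchanged:: 4.12`` note above); string index names and ``create_index``-style specifications should both be accepted, with the latter converted via ``helpers_shared._index_document``. Assumes a reachable deployment and a ``status_1`` index created as shown:

    from pymongo import ASCENDING, MongoClient

    coll = MongoClient().test.orders
    coll.create_index([("status", ASCENDING)], name="status_1")

    # Either form selects the index that supports the query predicate.
    values = coll.distinct("status", hint="status_1")
    values = coll.distinct("status", hint=[("status", ASCENDING)])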
@@ -3141,7 +3150,7 @@ def distinct( """ if not isinstance(key, str): - raise TypeError("key must be an instance of str") + raise TypeError(f"key must be an instance of str, not {type(key)}") cmd = {"distinct": self._name, "key": key} if filter is not None: if "query" in kwargs: @@ -3151,6 +3160,10 @@ def distinct( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment + if hint is not None: + if not isinstance(hint, str): + hint = helpers_shared._index_document(hint) + cmd["hint"] = hint # type: ignore[assignment] def _cmd( session: Optional[ClientSession], @@ -3189,7 +3202,7 @@ def _find_and_modify( common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError( - "return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER" + f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}" ) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd = {"findAndModify": self._name, "query": filter, "new": return_document} diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 3a4372856a..e23519d740 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -94,7 +94,9 @@ def __init__( self.batch_size(batch_size) if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) def __del__(self) -> None: self._die_no_lock() @@ -115,7 +117,7 @@ def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]: :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index b35098a327..31c4604f89 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -146,9 +146,9 @@ def __init__( spec: Mapping[str, Any] = filter or {} validate_is_mapping("filter", spec) if not isinstance(skip, int): - raise TypeError("skip must be an instance of int") + raise TypeError(f"skip must be an instance of int, not {type(skip)}") if not isinstance(limit, int): - raise TypeError("limit must be an instance of int") + raise TypeError(f"limit must be an instance of int, not {type(limit)}") validate_boolean("no_cursor_timeout", no_cursor_timeout) if no_cursor_timeout and not self._explicit_session: warnings.warn( @@ -171,7 +171,7 @@ def __init__( validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") # Only set if allow_disk_use is provided by the user, else None. 
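All of these validation errors now report the offending type, which makes the failure mode obvious at a glance. An illustrative session, where ``coll`` stands for any existing ``Collection`` (the exact repr comes from the new f-strings above):

    >>> coll.find().batch_size("10")  # wrong type on purpose
    Traceback (most recent call last):
        ...
    TypeError: batch_size must be an integer, not <class 'str'>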
@@ -388,7 +388,7 @@ def add_option(self, mask: int) -> Cursor[_DocumentType]: cursor.add_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -408,7 +408,7 @@ def remove_option(self, mask: int) -> Cursor[_DocumentType]: cursor.remove_option(2) """ if not isinstance(mask, int): - raise TypeError("mask must be an int") + raise TypeError(f"mask must be an int, not {type(mask)}") self._check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: @@ -432,7 +432,7 @@ def allow_disk_use(self, allow_disk_use: bool) -> Cursor[_DocumentType]: .. versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): - raise TypeError("allow_disk_use must be a bool") + raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}") self._check_okay_to_chain() self._allow_disk_use = allow_disk_use @@ -451,7 +451,7 @@ def limit(self, limit: int) -> Cursor[_DocumentType]: .. seealso:: The MongoDB documentation on `limit `_. """ if not isinstance(limit, int): - raise TypeError("limit must be an integer") + raise TypeError(f"limit must be an integer, not {type(limit)}") if self._exhaust: raise InvalidOperation("Can't use limit and exhaust together.") self._check_okay_to_chain() @@ -479,7 +479,7 @@ def batch_size(self, batch_size: int) -> Cursor[_DocumentType]: :param batch_size: The size of each batch of results requested. """ if not isinstance(batch_size, int): - raise TypeError("batch_size must be an integer") + raise TypeError(f"batch_size must be an integer, not {type(batch_size)}") if batch_size < 0: raise ValueError("batch_size must be >= 0") self._check_okay_to_chain() @@ -499,7 +499,7 @@ def skip(self, skip: int) -> Cursor[_DocumentType]: :param skip: the number of results to skip """ if not isinstance(skip, int): - raise TypeError("skip must be an integer") + raise TypeError(f"skip must be an integer, not {type(skip)}") if skip < 0: raise ValueError("skip must be >= 0") self._check_okay_to_chain() @@ -520,7 +520,7 @@ def max_time_ms(self, max_time_ms: Optional[int]) -> Cursor[_DocumentType]: :param max_time_ms: the time limit after which the operation is aborted """ if not isinstance(max_time_ms, int) and max_time_ms is not None: - raise TypeError("max_time_ms must be an integer or None") + raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}") self._check_okay_to_chain() self._max_time_ms = max_time_ms @@ -543,7 +543,9 @@ def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> Cursor[_Documen .. versionadded:: 3.2 """ if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: - raise TypeError("max_await_time_ms must be an integer or None") + raise TypeError( + f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}" + ) self._check_okay_to_chain() # Ignore max_await_time_ms if not tailable or await_data is False. @@ -677,7 +679,7 @@ def max(self, spec: _Sort) -> Cursor[_DocumentType]: .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._max = dict(spec) @@ -699,7 +701,7 @@ def min(self, spec: _Sort) -> Cursor[_DocumentType]: .. 
versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): - raise TypeError("spec must be an instance of list or tuple") + raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}") self._check_okay_to_chain() self._min = dict(spec) @@ -1122,7 +1124,8 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index a0bef55343..a11674b9aa 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -122,7 +122,7 @@ def __init__( from pymongo.synchronous.mongo_client import MongoClient if not isinstance(name, str): - raise TypeError("name must be an instance of str") + raise TypeError(f"name must be an instance of str, not {type(name)}") if not isinstance(client, MongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor. @@ -1303,7 +1303,7 @@ def drop_collection( name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str") + raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}") encrypted_fields = self._get_encrypted_fields( {"encryptedFields": encrypted_fields}, name, @@ -1367,7 +1367,9 @@ def validate_collection( name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str or Collection") + raise TypeError( + f"name_or_collection must be an instance of str or Collection, not {type(name)}" + ) cmd = {"validate": name, "scandata": scandata, "full": full} if comment is not None: cmd["comment"] = comment diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index d41169861f..5f9bdac4b7 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,7 +15,6 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations -import asyncio import contextlib import enum import socket @@ -71,23 +70,23 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall +from pymongo.network_layer import sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult -from pymongo.ssl_support import get_ssl_context +from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context from pymongo.synchronous.collection import Collection from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.pool import ( - _configured_socket, - _get_timeout_details, - _raise_connection_failure, -) from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -127,8 +126,6 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptionError(exc) from exc @@ -159,6 +156,7 @@ def __init__( self.mongocryptd_client = mongocryptd_client self.opts = opts self._spawned = False + self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC) def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -170,7 +168,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: endpoint = kms_context.endpoint message = kms_context.message provider = kms_context.kms_provider - ctx = self.opts._kms_ssl_contexts.get(provider) + ctx = self._kms_ssl_contexts.get(provider) if ctx is None: # Enable strict certificate verification, OCSP, match hostname, and # SNI using the system default CA certificates. @@ -182,6 +180,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: False, # allow_invalid_certificates False, # allow_invalid_hostnames False, # disable_ocsp_endpoint_check + _IS_SYNC, ) # CSOT: set timeout for socket creation. connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001) @@ -244,7 +243,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) raise exc from final_err - def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: + def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]: """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads @@ -253,12 +252,10 @@ def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: :param database: The database on which to run listCollections. :param filter: The filter to pass to listCollections. 
- :return: The first document from the listCollections command response as BSON. + :return: All documents from the listCollections command response as BSON. """ with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor: - for doc in cursor: - return _dict_to_bson(doc, False, _DATA_KEY_OPTS) - return None + return [_dict_to_bson(doc, False, _DATA_KEY_OPTS) for doc in cursor] def spawn(self) -> None: """Spawn mongocryptd. @@ -320,7 +317,9 @@ def insert_data_key(self, data_key: bytes) -> Binary: raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS) data_key_id = raw_doc.get("_id") if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: - raise TypeError("data_key _id must be Binary with a UUID subtype") + raise TypeError( + f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}" + ) assert self.key_vault_coll is not None self.key_vault_coll.insert_one(raw_doc) @@ -396,6 +395,8 @@ def __init__(self, client: MongoClient[_DocumentTypeArg], opts: AutoEncryptionOp encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS) self._bypass_auto_encryption = opts._bypass_auto_encryption self._internal_client = None + # Parse kms_ssl_contexts here so that any parsing errors are raised before internal clients are created + opts._kms_ssl_contexts(_IS_SYNC) def _get_internal_client( encrypter: _Encrypter, mongo_client: MongoClient[_DocumentTypeArg] @@ -443,6 +444,7 @@ def _get_internal_client( bypass_encryption=opts._bypass_auto_encryption, encrypted_fields_map=encrypted_fields_map, bypass_query_analysis=opts._bypass_query_analysis, + key_expiration_ms=opts._key_expiration_ms, ), ) self._closed = False @@ -545,11 +547,10 @@ class QueryType(str, enum.Enum): def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions: - opts = MongoCryptOptions(**kwargs) - # Opt into range V2 encryption. - if hasattr(opts, "enable_range_v2"): - opts.enable_range_v2 = True - return opts + # For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms. + if kwargs.get("key_expiration_ms") is None: + kwargs.pop("key_expiration_ms", None) + return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True) class ClientEncryption(Generic[_DocumentType]): @@ -562,6 +563,7 @@ def __init__( key_vault_client: MongoClient[_DocumentTypeArg], codec_options: CodecOptions[_DocumentTypeArg], kms_tls_options: Optional[Mapping[str, Any]] = None, + key_expiration_ms: Optional[int] = None, ) -> None: """Explicit client-side field level encryption. @@ -628,7 +630,12 @@ def __init__( Or to supply a client certificate:: kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} + :param key_expiration_ms: The cache expiration time for data encryption keys. + Defaults to ``None``, which defers to libmongocrypt's default (currently 60000 ms). + Set to 0 to disable key expiration. + .. versionchanged:: 4.12 + Added the `key_expiration_ms` parameter. .. versionchanged:: 4.0 Added the `kms_tls_options` parameter and the "kmip" KMS provider. @@ -642,7 +649,9 @@ def __init__( ) if not isinstance(codec_options, CodecOptions): - raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + raise TypeError( + f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}" + ) if not isinstance(key_vault_client, MongoClient): # This is for compatibility with mocked and subclassed types, such as in Motor.
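Two behaviors worth illustrating from this hunk: ``collection_info`` now returns every matching ``listCollections`` document (paired with ``enable_multiple_collinfo=True``), and data-key caching becomes tunable through ``key_expiration_ms``. A hypothetical ``ClientEncryption`` construction showing the new knob; ``local_master_key`` and the client are placeholders, and pymongocrypt must be installed:

    from bson.codec_options import CodecOptions
    from pymongo import MongoClient
    from pymongo.encryption import ClientEncryption

    client = MongoClient()
    local_master_key = b"\x00" * 96  # illustrative only; use a real 96-byte key

    client_encryption = ClientEncryption(
        kms_providers={"local": {"key": local_master_key}},
        key_vault_namespace="encryption.__keyVault",
        key_vault_client=client,
        codec_options=CodecOptions(),
        # None defers to libmongocrypt's default (currently 60000 ms);
        # 0 disables data encryption key cache expiration entirely.
        key_expiration_ms=0,
    )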
@@ -658,14 +667,20 @@ def __init__( key_vault_coll = key_vault_client[db][coll] opts = AutoEncryptionOpts( - kms_providers, key_vault_namespace, kms_tls_options=kms_tls_options + kms_providers, + key_vault_namespace, + kms_tls_options=kms_tls_options, + key_expiration_ms=key_expiration_ms, ) + self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO( None, key_vault_coll, None, opts ) self._encryption = ExplicitEncrypter( self._io_callbacks, - _create_mongocrypt_options(kms_providers=kms_providers, schema_map=None), + _create_mongocrypt_options( + kms_providers=kms_providers, schema_map=None, key_expiration_ms=key_expiration_ms + ), ) # Use the same key vault collection as the callback. assert self._io_callbacks.key_vault_coll is not None @@ -692,6 +707,7 @@ def create_encrypted_collection( creation. :class:`~pymongo.errors.EncryptionError` will be raised if the collection already exists. + :param database: the database in which to create the collection :param name: the name of the collection to create :param encrypted_fields: Document that describes the encrypted fields for Queryable Encryption. The "keyId" may be set to ``None`` to auto-generate the data keys. For example: @@ -756,8 +772,6 @@ def create_encrypted_collection( database.create_collection(name=name, **kwargs), encrypted_fields, ) - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index f800e7dcc8..bc69a49e80 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a694a58c1e..99a517e5c1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License.
You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -42,6 +42,7 @@ TYPE_CHECKING, Any, Callable, + Collection, ContextManager, FrozenSet, Generator, @@ -59,7 +60,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -80,9 +81,15 @@ _create_lock, _release_locks, ) -from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.logger import ( + _CLIENT_LOGGER, + _COMMAND_LOGGER, + _debug_log, + _log_client_error, + _log_or_warn, +) from pymongo.message import _CursorAddress, _GetMore, _Query -from pymongo.monitoring import ConnectionClosedReason +from pymongo.monitoring import ConnectionClosedReason, _EventListeners from pymongo.operations import ( DeleteMany, DeleteOne, @@ -94,9 +101,10 @@ ) from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ClientBulkWriteResult +from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database +from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession @@ -112,11 +120,14 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -130,6 +141,7 @@ from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager + from pymongo.synchronous.encryption import _Encrypter from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -187,7 +199,7 @@ def __init__( execute. The `host` parameter can be a full `mongodb URI - `_, in addition to + `_, in addition to a simple hostname. It can also be a list of hostnames but no more than one URI. Any port specified in the host string(s) will override the `port` parameter. For username and @@ -274,7 +286,9 @@ def __init__( :param type_registry: instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. - :param datetime_conversion: Specifies how UTC datetimes should be decoded + :param kwargs: **Additional optional parameters available as keyword arguments:** + + - `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded within BSON. 
Valid options include 'datetime_ms' to return as a DatetimeMS, 'datetime' to return as a datetime.datetime and raising a ValueError for out-of-range values, 'datetime_auto' to @@ -282,9 +296,6 @@ def __init__( out-of-range and 'datetime_clamp' to clamp to the minimum and maximum possible datetimes. Defaults to 'datetime'. See :ref:`handling-out-of-range-datetimes` for details. - - | **Other optional parameters can be passed as keyword arguments:** - - `directConnection` (optional): if ``True``, forces this client to connect directly to the specified MongoDB host as a standalone. If ``false``, the client connects to the entire replica set of @@ -748,7 +759,13 @@ def __init__( if port is None: port = self.PORT if not isinstance(port, int): - raise TypeError("port must be an instance of int") + raise TypeError(f"port must be an instance of int, not {type(port)}") + self._host = host + self._port = port + self._topology: Topology = None # type: ignore[assignment] + self._timeout: float | None = None + self._topology_settings: TopologySettings = None # type: ignore[assignment] + self._event_listeners: _EventListeners | None = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -759,8 +776,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} - seeds = set() + self._seeds = set() + is_srv = False username = None password = None dbase = None @@ -768,41 +787,34 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. - timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) - res = uri_parser.parse_uri( + res = _validate_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) - seeds.update(res["nodelist"]) + is_srv = entity.startswith(SRV_SCHEME) + self._seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) - if not seeds: + self._seeds.update(split_hosts(entity, self._port)) + if not self._seeds: raise ConfigurationError("need to specify at least one host") - for hostname in [node[0] for node in seeds]: + for hostname in [node[0] for node in self._seeds]: if _detect_external_db(hostname): break @@ -819,80 +831,180 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. 
- keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, + self._encrypter: Optional[_Encrypter] = None + + self._resolve_srv_info.update( + { + "is_srv": is_srv, + "username": username, + "password": password, + "dbase": dbase, + "seeds": self._seeds, + "fqdn": fqdn, + "srv_service_name": srv_service_name, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, ) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False - self._init_background() + self._loop: Optional[asyncio.AbstractEventLoop] = None + if not is_srv: + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. 
+ timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = uri_parser._parse_srv( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + opts = self._normalize_and_validate_options(opts, seeds) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) + self._options = ClientOptions( + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC + ) + + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: + self._event_listeners = self._options.pool_options._event_listeners + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + topology_id=self._topology_settings._topology_id if self._topology_settings else None, + ) if self._options.auto_encryption_opts: from pymongo.synchronous.encryption import _Encrypter self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. 
- MongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. + opts.update(keyword_opts) + return opts def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" @@ -900,6 +1012,10 @@ def _connect(self) -> None: def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + MongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1088,6 +1204,16 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + if self._topology is None: + servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds} + return TopologyDescription( + TOPOLOGY_TYPE.Unknown, + servers, + None, + None, + None, + self._topology_settings, + ) return self._topology.description @property @@ -1101,6 +1227,8 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + if self._topology is None: + return frozenset() description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1114,16 +1242,24 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1139,13 +1275,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... 
- options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if self._topology is None: + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] + else: + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1444,6 +1583,8 @@ def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + if self._topology is None: + self._get_topology() topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1466,6 +1607,8 @@ def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + if self._topology is None: + self._get_topology() return self._topology.get_primary() # type: ignore[return-value] @property @@ -1479,6 +1622,8 @@ def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + if self._topology is None: + self._get_topology() return self._topology.get_secondaries() @property @@ -1489,6 +1634,8 @@ def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + if self._topology is None: + self._get_topology() return self._topology.get_arbiters() @property @@ -1547,6 +1694,8 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ + if self._topology is None: + return session_ids = self._topology.pop_all_sessions() if session_ids: self._end_sessions(session_ids) @@ -1559,6 +1708,12 @@ def close(self) -> None: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() self._closed = True + if not _IS_SYNC: + asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.closing. @@ -1570,7 +1725,17 @@ def _get_topology(self) -> Topology: If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use MongoClient in different event loop. MongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: + if self._resolve_srv_info["is_srv"]: + self._resolve_srv() + self._init_background() self._topology.open() with self._lock: self._kill_cursors_executor.open() @@ -1965,7 +2130,7 @@ def _close_cursor_now( The cursor is closed synchronously on the current thread. 
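A consequence of deferring topology construction (and, for ``mongodb+srv://`` URIs, DNS resolution) until ``_get_topology``: client identity can no longer lean on ``self._topology``. Equality and hashing therefore derive from ``eq_props()``, that is, the sorted seeds, replica set name, SRV fqdn, and SRV service name. Under this change, two clients built from the same inputs compare equal before any I/O happens:

    from pymongo import MongoClient

    # connect=False performs no I/O; an SRV URI would not even hit DNS here.
    c1 = MongoClient("localhost", 27017, connect=False)
    c2 = MongoClient("localhost", 27017, connect=False)
    assert c1 == c2
    assert hash(c1) == hash(c2)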
""" if not isinstance(cursor_id, int): - raise TypeError("cursor_id must be an instance of int") + raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}") try: if conn_mgr: @@ -2032,15 +2197,13 @@ def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it # can be caught in _process_periodic_tasks raise else: - helpers_shared._handle_exception() + _log_client_error() # Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: @@ -2048,13 +2211,11 @@ def _process_kill_cursors(self) -> None: for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors(cursor_ids, address, topology, session=None) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers_shared._handle_exception() + _log_client_error() # This method is run periodically by a background thread. def _process_periodic_tasks(self) -> None: @@ -2064,13 +2225,11 @@ def _process_periodic_tasks(self) -> None: try: self._process_kill_cursors() self._topology.update_pool() - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers_shared._handle_exception() + _log_client_error() def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession] @@ -2087,7 +2246,9 @@ def _tmp_session( """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.ClientSession): - raise ValueError("'session' argument must be a ClientSession or None.") + raise ValueError( + f"'session' argument must be a ClientSession or None, not {type(session)}" + ) # Don't call end_session. yield session return @@ -2235,7 +2396,9 @@ def drop_database( name = name.name if not isinstance(name, str): - raise TypeError("name_or_database must be an instance of str or a Database") + raise TypeError( + f"name_or_database must be an instance of str or a Database, not {type(name)}" + ) with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: self[name]._command( @@ -2494,6 +2657,7 @@ def handle( self.completed_handshake, self.service_id, ) + assert self.client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) def __enter__(self) -> _MongoClientErrorHandler: @@ -2543,6 +2707,7 @@ def __init__( self._deprioritized_servers: list[Server] = [] self._operation = operation self._operation_id = operation_id + self._attempt_number = 0 def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2585,6 +2750,7 @@ def run(self) -> T: raise self._retrying = True self._last_error = exc + self._attempt_number += 1 else: raise @@ -2606,6 +2772,7 @@ def run(self) -> T: raise self._last_error from exc else: raise + self._attempt_number += 1 if self._bulk: self._bulk.retrying = True else: @@ -2684,6 +2851,14 @@ def _write(self) -> T: # not support sessions raise the last error. 
self._check_last_error() self._retryable = False + if self._retrying: + _debug_log( + _COMMAND_LOGGER, + message=f"Retrying write attempt number {self._attempt_number}", + clientId=self._client._topology_settings._topology_id, + commandName=self._operation, + operationId=self._operation_id, + ) return self._func(self._session, conn, self._retryable) # type: ignore except PyMongoError as exc: if not self._retryable: @@ -2705,6 +2880,14 @@ def _read(self) -> T: ): if self._retrying and not self._retryable: self._check_last_error() + if self._retrying: + _debug_log( + _COMMAND_LOGGER, + message=f"Retrying read attempt number {self._attempt_number}", + clientId=self._client._topology_settings._topology_id, + commandName=self._operation, + operationId=self._operation_id, + ) return self._func(self._session, self._server, conn, read_pref) # type: ignore diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index df4130d4ab..f41040801f 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,11 +21,11 @@ import logging import time import weakref -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage @@ -33,10 +33,14 @@ from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription -from pymongo.srv_resolver import _SrvResolver +from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: - from pymongo.synchronous.pool import Connection, Pool, _CancellationContext + from pymongo.synchronous.pool import ( # type: ignore[attr-defined] + Connection, + Pool, + _CancellationContext, + ) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology @@ -112,9 +116,9 @@ def close(self) -> None: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + def join(self) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,6 +193,9 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() + def join(self) -> None: + asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value] + def close(self) -> None: self.gc_safe_close() self._rtt_monitor.close() @@ -250,15 +257,7 @@ def _check_server(self) -> ServerDescription: self._conn_id = None start = time.monotonic() try: - try: - return self._check_once() - except (OperationFailure, NotPrimaryError) as exc: - # Update max cluster time even when hello fails. 
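The new ``_debug_log`` calls surface retry attempts through the standard command logger instead of looping silently. A sketch of enabling them, assuming the command logger keeps its ``pymongo.command`` name:

    import logging

    logging.basicConfig()
    logging.getLogger("pymongo.command").setLevel(logging.DEBUG)
    # Retried operations should then emit messages such as
    # "Retrying write attempt number 1" tagged with clientId,
    # commandName, and operationId.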
- details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) - raise - except asyncio.CancelledError: - raise + return self._check_once() except ReferenceError: raise except Exception as error: @@ -273,6 +272,7 @@ def _check_server(self) -> ServerDescription: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.HEARTBEAT_FAIL, topologyId=self._topology._topology_id, serverHost=address[0], serverPort=address[1], @@ -280,7 +280,6 @@ def _check_server(self) -> ServerDescription: durationMS=duration * 1000, failure=error, driverConnectionId=self._conn_id, - message=_SDAMStatusMessage.HEARTBEAT_FAIL, ) self._reset_connection() if isinstance(error, _OperationCancelled): @@ -312,13 +311,13 @@ def _check_once(self) -> ServerDescription: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.HEARTBEAT_START, topologyId=self._topology._topology_id, driverConnectionId=conn.id, serverConnectionId=conn.server_connection_id, serverHost=address[0], serverPort=address[1], awaited=awaited, - message=_SDAMStatusMessage.HEARTBEAT_START, ) self._cancel_context = conn.cancel_context @@ -338,6 +337,7 @@ def _check_once(self) -> ServerDescription: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.HEARTBEAT_SUCCESS, topologyId=self._topology._topology_id, driverConnectionId=conn.id, serverConnectionId=conn.server_connection_id, @@ -346,7 +346,6 @@ def _check_once(self) -> ServerDescription: awaited=awaited, durationMS=round_trip_time * 1000, reply=response.document, - message=_SDAMStatusMessage.HEARTBEAT_SUCCESS, ) return sd @@ -355,7 +354,6 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: Can raise ConnectionFailure or OperationFailure. """ - cluster_time = self._topology.max_cluster_time() start = time.monotonic() if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). @@ -365,13 +363,12 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: ): # Initiate streaming hello (MongoDB 4.4+). response = conn._hello( - cluster_time, self._server_description.topology_version, self._settings.heartbeat_frequency, ) else: # New connection handshake or polling hello (MongoDB <4.4). - response = conn._hello(cluster_time, None, None) + response = conn._hello(None, None) duration = _monotonic_duration(start) return response, duration @@ -424,8 +421,6 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception - except asyncio.CancelledError: - raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -489,8 +484,6 @@ def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. self.close() - except asyncio.CancelledError: - raise except Exception: self._pool.reset() diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7206dca735..9559a5a542 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -17,7 +17,6 @@ import datetime import logging -import time from typing import ( TYPE_CHECKING, Any, @@ -31,20 +30,16 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message -from pymongo.common import MAX_MESSAGE_SIZE -from pymongo.compression_support import _NO_COMPRESSION, decompress +from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import ( NotPrimaryError, OperationFailure, - ProtocolError, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _UNPACK_COMPRESSION_HEADER, - _UNPACK_HEADER, - receive_data, + receive_message, sendall, ) @@ -168,8 +163,8 @@ def command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=spec, commandName=next(iter(spec)), databaseName=dbname, @@ -194,7 +189,7 @@ def command( ) try: - sendall(conn.conn, msg) + sendall(conn.conn.get_conn, msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None @@ -207,6 +202,10 @@ def command( ) response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time if client: client._process_response(response_doc, session) if check: @@ -226,8 +225,8 @@ def command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(spec)), @@ -260,8 +259,8 @@ def command( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=response_doc, commandName=next(iter(spec)), @@ -297,45 +296,3 @@ def command( ) return response_doc # type: ignore[return-value] - - -def receive_message( - conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE -) -> Union[_OpReply, _OpMsg]: - """Receive a raw BSON message or raise socket.error.""" - if _csot.get_timeout(): - deadline = _csot.get_deadline() - else: - timeout = conn.conn.gettimeout() - if timeout: - deadline = time.monotonic() + timeout - else: - deadline = None - # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) - # No request_id for exhaust cursor "getMore". 
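The ``if not conn.ready`` block added to ``command`` above replaces the heartbeat-side gossip this patch removes from ``monitor.py`` (note ``_hello`` losing its ``cluster_time`` argument): ``$clusterTime`` from the handshake reply is stashed on the not-yet-ready connection for the client to advance later. An annotated copy of the added lines:

    # Gossip $clusterTime from the connection handshake: until the
    # connection is marked ready, remember the server's cluster time so
    # the client can pick it up once the handshake completes.
    if not conn.ready:
        cluster_time = response_doc.get("$clusterTime")
        if cluster_time:
            conn._cluster_time = cluster_time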
- if request_id is not None: - if request_id != response_to: - raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") - if length <= 16: - raise ProtocolError( - f"Message length ({length!r}) not longer than standard message header size (16)" - ) - if length > max_message_size: - raise ProtocolError( - f"Message length ({length!r}) is larger than server max " - f"message size ({max_message_size!r})" - ) - if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) - data = decompress(receive_data(conn, length - 25, deadline), compressor_id) - else: - data = receive_data(conn, length - 16, deadline) - - try: - unpack_reply = _UNPACK_REPLY[op_code] - except KeyError: - raise ProtocolError( - f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" - ) from None - return unpack_reply(data) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 05f930d480..44aec31a86 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -17,11 +17,8 @@ import asyncio import collections import contextlib -import functools import logging import os -import socket -import ssl import sys import time import weakref @@ -49,16 +46,13 @@ from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, - ConnectionFailure, DocumentTooLarge, ExecutionTimeout, InvalidOperation, - NetworkTimeout, NotPrimaryError, OperationFailure, PyMongoError, WaitQueueTimeoutError, - _CertificateError, ) from pymongo.hello import Hello, HelloCompat from pymongo.lock import ( @@ -76,16 +70,23 @@ ConnectionCheckOutFailedReason, ConnectionClosedReason, ) -from pymongo.network_layer import sendall +from pymongo.network_layer import NetworkingInterface, receive_message, sendall from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + SSLErrors, + _CancellationContext, + _configured_socket_interface, + _get_timeout_details, + _raise_connection_failure, + format_timeout_details, +) from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI, SSLError from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth -from pymongo.synchronous.network import command, receive_message +from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.network import command if TYPE_CHECKING: from bson import CodecOptions @@ -96,13 +97,12 @@ ZstdContext, ) from pymongo.message import _OpMsg, _OpReply - from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.synchronous.auth import _AuthContext from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler - from pymongo.typings import ClusterTime, _Address, _CollationIn + from 
pymongo.typings import _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -123,133 +123,6 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = True -_MAX_TCP_KEEPIDLE = 120 -_MAX_TCP_KEEPINTVL = 10 -_MAX_TCP_KEEPCNT = 9 - -if sys.platform == "win32": - try: - import _winreg as winreg - except ImportError: - import winreg - - def _query(key, name, default): - try: - value, _ = winreg.QueryValueEx(key, name) - # Ensure the value is a number or raise ValueError. - return int(value) - except (OSError, ValueError): - # QueryValueEx raises OSError when the key does not exist (i.e. - # the system is using the Windows default value). - return default - - try: - with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" - ) as key: - _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) - _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) - except OSError: - # We could not check the default values because winreg.OpenKey failed. - # Assume the system is using the default values. - _WINDOWS_TCP_IDLE_MS = 7200000 - _WINDOWS_TCP_INTERVAL_MS = 1000 - - def _set_keepalive_times(sock): - idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) - interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) - if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: - sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) - -else: - - def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: - if hasattr(socket, tcp_option): - sockopt = getattr(socket, tcp_option) - try: - # PYTHON-1350 - NetBSD doesn't implement getsockopt for - # TCP_KEEPIDLE and friends. Don't attempt to set the - # values there. - default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) - if default > max_value: - sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) - except OSError: - pass - - def _set_keepalive_times(sock: socket.socket) -> None: - _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) - _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) - _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) - - -def _raise_connection_failure( - address: Any, - error: Exception, - msg_prefix: Optional[str] = None, - timeout_details: Optional[dict[str, float]] = None, -) -> NoReturn: - """Convert a socket.error to ConnectionFailure and raise it.""" - host, port = address - # If connecting to a Unix socket, port will be None. - if port is not None: - msg = "%s:%d: %s" % (host, port, error) - else: - msg = f"{host}: {error}" - if msg_prefix: - msg = msg_prefix + msg - if "configured timeouts" not in msg: - msg += format_timeout_details(timeout_details) - if isinstance(error, socket.timeout): - raise NetworkTimeout(msg) from error - elif isinstance(error, SSLError) and "timed out" in str(error): - # Eventlet does not distinguish TLS network timeouts from other - # SSLErrors (https://github.com/eventlet/eventlet/issues/692). - # Luckily, we can work around this limitation because the phrase - # 'timed out' appears in all the timeout related SSLErrors raised. 
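The connection-failure helper whose removal straddles this point moves verbatim into pymongo.pool_shared so both pools can share it. The idea it implements, classifying a low-level error before re-raising, sketched with stand-in exception types (the real code raises NetworkTimeout and AutoReconnect):

    import socket

    class NetworkTimeoutSketch(Exception): ...   # stand-in for NetworkTimeout
    class AutoReconnectSketch(Exception): ...    # stand-in for AutoReconnect

    def classify(address: tuple[str, int], error: Exception) -> Exception:
        host, port = address
        msg = f"{host}:{port}: {error}"
        # Timeouts get their own type so retry logic can treat them specially;
        # everything else becomes a generic "reconnectable" failure.
        if isinstance(error, socket.timeout):
            return NetworkTimeoutSketch(msg)
        return AutoReconnectSketch(msg)

    assert isinstance(classify(("db", 27017), socket.timeout("timed out")), NetworkTimeoutSketch)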
- raise NetworkTimeout(msg) from error - else: - raise AutoReconnect(msg) from error - - -def _get_timeout_details(options: PoolOptions) -> dict[str, float]: - details = {} - timeout = _csot.get_timeout() - socket_timeout = options.socket_timeout - connect_timeout = options.connect_timeout - if timeout: - details["timeoutMS"] = timeout * 1000 - if socket_timeout and not timeout: - details["socketTimeoutMS"] = socket_timeout * 1000 - if connect_timeout: - details["connectTimeoutMS"] = connect_timeout * 1000 - return details - - -def format_timeout_details(details: Optional[dict[str, float]]) -> str: - result = "" - if details: - result += " (configured timeouts:" - for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: - if timeout in details: - result += f" {timeout}: {details[timeout]}ms," - result = result[:-1] - result += ")" - return result - - -class _CancellationContext: - def __init__(self) -> None: - self._cancelled = False - - def cancel(self) -> None: - """Cancel this context.""" - self._cancelled = True - - @property - def cancelled(self) -> bool: - """Was cancel called?""" - return self._cancelled - class Connection: """Store a connection with some metadata. @@ -261,7 +134,11 @@ class Connection: """ def __init__( - self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int + self, + conn: NetworkingInterface, + pool: Pool, + address: tuple[str, int], + id: int, ): self.pool_ref = weakref.ref(pool) self.conn = conn @@ -310,13 +187,15 @@ def __init__( self.connect_rtt = 0.0 self._client_id = pool._client_id self.creation_time = time.monotonic() + # For gossiping $clusterTime from the connection handshake to the client. + self._cluster_time = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" if timeout == self.last_timeout: return self.last_timeout = timeout - self.conn.settimeout(timeout) + self.conn.get_conn.settimeout(timeout) def apply_timeout( self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]] @@ -374,11 +253,10 @@ def hello_cmd(self) -> dict[str, Any]: return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} def hello(self) -> Hello: - return self._hello(None, None, None) + return self._hello(None, None) def _hello( self, - cluster_time: Optional[ClusterTime], topology_version: Optional[Any], heartbeat_frequency: Optional[int], ) -> Hello[dict[str, Any]]: @@ -401,9 +279,6 @@ def _hello( if self.opts.connect_timeout: self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency) - if not performing_handshake and cluster_time is not None: - cmd["$clusterTime"] = cluster_time - creds = self.opts._credentials if creds: if creds.mechanism == "DEFAULT" and creds.username: @@ -559,7 +434,7 @@ def command( ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -575,7 +450,8 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: ) try: - sendall(self.conn, message) + sendall(self.conn.get_conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
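The send_message change just above is part of a broader swap: Connection.conn is now a NetworkingInterface rather than a raw socket, so call sites reach the socket through .get_conn. A rough sketch of the shape these call sites assume, not the real class:

    import socket

    class NetworkingInterfaceSketch:
        def __init__(self, sock: socket.socket) -> None:
            self._sock = sock

        @property
        def get_conn(self) -> socket.socket:
            # Expose the underlying socket for sync paths (settimeout, sendall).
            return self._sock

        def close(self) -> None:
            self._sock.close()

    a, b = socket.socketpair()
    iface = NetworkingInterfaceSketch(a)
    iface.get_conn.settimeout(1.0)  # mirrors set_conn_timeout above
    iface.close()
    b.close()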
except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +462,7 @@ def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: """ try: return receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -652,8 +529,8 @@ def authenticate(self, reauthenticate: bool = False) -> None: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CONN_READY, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=self.id, @@ -683,8 +560,8 @@ def close_conn(self, reason: Optional[str]) -> None: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CONN_CLOSED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=self.id, @@ -702,14 +579,15 @@ def _close_conn(self) -> None: # shutdown. try: self.conn.close() - except asyncio.CancelledError: - raise except Exception: # noqa: S110 pass def conn_closed(self) -> bool: """Return True if we know socket has been closed, False otherwise.""" - return self.socket_checker.socket_closed(self.conn) + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() def send_cluster_time( self, @@ -758,7 +636,7 @@ def _raise_connection_failure(self, error: BaseException) -> NoReturn: reason = ConnectionClosedReason.ERROR self.close_conn(reason) # SSLError from PyOpenSSL inherits directly from Exception. - if isinstance(error, (IOError, OSError, SSLError)): + if isinstance(error, (IOError, OSError, *SSLErrors)): details = _get_timeout_details(self.opts) _raise_connection_failure(self.address, error, timeout_details=details) else: @@ -781,143 +659,6 @@ def __repr__(self) -> str: ) -def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: - """Given (host, port) and PoolOptions, connect and return a socket object. - - Can raise socket.error. - - This is a modified version of create_connection from CPython >= 2.7. - """ - host, port = address - - # Check if dealing with a unix domain socket - if host.endswith(".sock"): - if not hasattr(socket, "AF_UNIX"): - raise ConnectionFailure("UNIX-sockets are not supported on this system") - sock = socket.socket(socket.AF_UNIX) - # SOCK_CLOEXEC not supported for Unix sockets. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.connect(host) - return sock - except OSError: - sock.close() - raise - - # Don't try IPv6 if we don't support it. Also skip it if host - # is 'localhost' (::1 is fine). Avoids slow connect issues - # like PYTHON-356. - family = socket.AF_INET - if socket.has_ipv6 and host != "localhost": - family = socket.AF_UNSPEC - - err = None - for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] - af, socktype, proto, dummy, sa = res - # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited - # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 - # all file descriptors are created non-inheritable. See PEP 446. 
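conn_closed above now branches: sync builds keep the select()-based SocketChecker probe on the raw socket, while async builds ask the transport via is_closing(). The sync probe, loosely sketched and simplified from pymongo.socket_checker:

    import select
    import socket

    def socket_closed_sketch(sock: socket.socket) -> bool:
        """Best-effort check: has the peer closed or errored this socket?"""
        try:
            readable, _, errored = select.select([sock], [], [sock], 0)
        except (OSError, ValueError):
            return True  # Bad file descriptor: treat as closed.
        if errored:
            return True
        # Readable with zero bytes available means the peer sent EOF.
        return bool(readable) and sock.recv(1, socket.MSG_PEEK) == b""

    a, b = socket.socketpair()
    assert socket_closed_sketch(a) is False
    b.close()
    assert socket_closed_sketch(a) is True
    a.close()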
- try: - sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) - except OSError: - # Can SOCK_CLOEXEC be defined even if the kernel doesn't support - # it? - sock = socket.socket(af, socktype, proto) - # Fallback when SOCK_CLOEXEC isn't available. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # CSOT: apply timeout to socket connect. - timeout = _csot.remaining() - if timeout is None: - timeout = options.connect_timeout - elif timeout <= 0: - raise socket.timeout("timed out") - sock.settimeout(timeout) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) - _set_keepalive_times(sock) - sock.connect(sa) - return sock - except OSError as e: - err = e - sock.close() - - if err is not None: - raise err - else: - # This likely means we tried to connect to an IPv6 only - # host with an OS/kernel or Python interpreter that doesn't - # support IPv6. The test case is Jython2.5.1 which doesn't - # support IPv6 at all. - raise OSError("getaddrinfo failed") - - -def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = _create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if HAS_SNI: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor( - None, - functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] - ) - else: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, SSLError) as exc: - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] - except _CertificateError: - ssl_sock.close() - raise - - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock - - class _PoolClosedError(PyMongoError): """Internal error raised when a thread tries to get a connection from a closed pool. 
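The _configured_socket removal above completes the move of connection establishment into pymongo.pool_shared._configured_socket_interface. Its TLS core, reduced to a stdlib sketch (host and timeout are placeholders; the real code also handles Unix sockets, keepalives, and CSOT deadlines):

    import socket
    import ssl

    def configured_tls_socket(host: str, port: int, timeout: float) -> ssl.SSLSocket:
        ctx = ssl.create_default_context()
        sock = socket.create_connection((host, port), timeout=timeout)
        # server_hostname enables SNI and SSLContext.check_hostname, which is
        # what the removed HAS_SNI branch was arranging.
        return ctx.wrap_socket(sock, server_hostname=host)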
@@ -962,7 +703,7 @@ class PoolState: # Do *not* explicitly inherit from object or Jython won't call __del__ -# http://bugs.jython.org/issue1057 +# https://bugs.jython.org/issue1057 class Pool: def __init__( self, @@ -1035,8 +776,8 @@ def __init__( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.POOL_CREATED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], **self.opts.non_default_options, @@ -1061,8 +802,8 @@ def ready(self) -> None: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.POOL_READY, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], ) @@ -1118,16 +859,22 @@ def _reset( # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. if close: - for conn in sockets: - conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: assert listeners is not None listeners.publish_pool_closed(self.address) if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.POOL_CLOSED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], ) @@ -1143,14 +890,20 @@ def _reset( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.POOL_CLEARED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], serviceId=service_id, ) - for conn in sockets: - conn.close_conn(ConnectionClosedReason.STALE) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + conn.close_conn(ConnectionClosedReason.STALE) def update_is_writable(self, is_writable: Optional[bool]) -> None: """Updates the is_writable attribute on all sockets currently in the @@ -1187,12 +940,20 @@ def remove_stale_sockets(self, reference_generation: int) -> None: return if self.opts.max_idle_time_seconds is not None: + close_conns = [] with self.lock: while ( self.conns and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): - conn = self.conns.pop() + close_conns.append(self.conns.pop()) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], + return_exceptions=True, + ) + else: + for conn in close_conns: conn.close_conn(ConnectionClosedReason.IDLE) while True: @@ -1213,14 +974,18 @@ def remove_stale_sockets(self, reference_generation: int) -> None: self._pending += 1 incremented = True conn = self.connect() + close_conn = False with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. 
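The _reset and remove_stale_sockets hunks above share one pattern: pop connections from the pool while holding the lock, then close them after releasing it (concurrently via asyncio.gather in async builds), so slow socket teardown never blocks other waiters. The shape, sketched synchronously:

    import threading
    from collections import deque
    from typing import Any, Callable

    lock = threading.Lock()
    conns: deque[Any] = deque([object(), object()])

    def drain_and_close(close: Callable[[Any], None]) -> None:
        stale = []
        with lock:          # Only list manipulation happens under the lock.
            while conns:
                stale.append(conns.pop())
        for conn in stale:  # Potentially slow I/O happens outside it.
            close(conn)

    drain_and_close(lambda conn: None)
    assert not conns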
if self.gen.get_overall() != reference_generation: - conn.close_conn(ConnectionClosedReason.STALE) - return - self.conns.appendleft(conn) - self.active_contexts.discard(conn.cancel_context) + close_conn = True + if not close_conn: + self.conns.appendleft(conn) + self.active_contexts.discard(conn.cancel_context) + if close_conn: + conn.close_conn(ConnectionClosedReason.STALE) + return finally: if incremented: # Notify after adding the socket to the pool. @@ -1254,15 +1019,16 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CONN_CREATED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=conn_id, ) try: - sock = _configured_socket(self.address, self.opts) + networking_interface = _configured_socket_interface(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: self.active_contexts.discard(tmp_context) @@ -1274,21 +1040,21 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CONN_CLOSED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=conn_id, reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), error=ConnectionClosedReason.ERROR, ) - if isinstance(error, (IOError, OSError, SSLError)): + if isinstance(error, (IOError, OSError, *SSLErrors)): details = _get_timeout_details(self.opts) _raise_connection_failure(self.address, error, timeout_details=details) raise - conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] + conn = Connection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type] with self.lock: self.active_contexts.add(conn.cancel_context) self.active_contexts.discard(tmp_context) @@ -1302,12 +1068,16 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect handler.contribute_socket(conn, completed_handshake=False) conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: with self.lock: self.active_contexts.discard(conn.cancel_context) conn.close_conn(ConnectionClosedReason.ERROR) raise + if handler: + handler.client._topology.receive_cluster_time(conn._cluster_time) + return conn @contextlib.contextmanager @@ -1337,8 +1107,8 @@ def checkout( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_STARTED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], ) @@ -1352,8 +1122,8 @@ def checkout( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=conn.id, @@ -1363,6 +1133,7 @@ def checkout( with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
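The receive_cluster_time call above closes the loop opened in network.py: the handshake response's $clusterTime is stashed on the connection, then forwarded to the topology once connect() finishes. The merge this relies on keeps whichever cluster time is newer; roughly, with the field name taken from the server's $clusterTime document:

    from __future__ import annotations

    from bson.timestamp import Timestamp

    def merge_cluster_time(current: dict | None, received: dict | None) -> dict | None:
        """Keep whichever $clusterTime document is more recent."""
        if received is None:
            return current
        if current is None or received["clusterTime"] > current["clusterTime"]:
            return received
        return current

    older = {"clusterTime": Timestamp(1, 1)}
    newer = {"clusterTime": Timestamp(2, 1)}
    assert merge_cluster_time(older, newer) is newer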
except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1400,8 +1171,8 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_FAILED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], reason="An error occurred while trying to establish a new connection", @@ -1434,8 +1205,8 @@ def _get_conn( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_FAILED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], reason="Connection pool was closed", @@ -1509,6 +1280,7 @@ def _get_conn( with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. @@ -1529,8 +1301,8 @@ def _get_conn( if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_FAILED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], reason="An error occurred while trying to establish a new connection", @@ -1562,8 +1334,8 @@ def checkin(self, conn: Connection) -> None: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKEDIN, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=conn.id, @@ -1583,8 +1355,8 @@ def checkin(self, conn: Connection) -> None: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CONN_CLOSED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], driverConnectionId=conn.id, @@ -1592,17 +1364,20 @@ def checkin(self, conn: Connection) -> None: error=ConnectionClosedReason.ERROR, ) else: + close_conn = False with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). if self.stale_generation(conn.generation, conn.service_id): - conn.close_conn(ConnectionClosedReason.STALE) + close_conn = True else: conn.update_last_checkin_time() conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) # Notify any threads waiting to create a connection. 
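checkin above now defers close_conn until the lock is released, matching the _reset change; the staleness test itself is just a generation comparison. Illustrated with a stand-in for the pool's generation counter:

    class GenerationSketch:
        """Stand-in for the pool's monotonically increasing generation."""

        def __init__(self) -> None:
            self._gen = 0

        def get_overall(self) -> int:
            return self._gen

        def inc(self) -> None:
            self._gen += 1

    def is_stale(gen: GenerationSketch, conn_generation: int) -> bool:
        # A connection minted before the last pool reset must not be reused.
        return conn_generation < gen.get_overall()

    gen = GenerationSketch()
    conn_generation = gen.get_overall()  # checked out at generation 0
    gen.inc()                            # pool reset, e.g. after a network error
    assert is_stale(gen, conn_generation)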
self._max_connecting_cond.notify() + if close_conn: + conn.close_conn(ConnectionClosedReason.STALE) with self.size_cond: if txn: @@ -1661,8 +1436,8 @@ def _raise_wait_queue_timeout(self, checkout_started_time: float) -> NoReturn: if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _CONNECTION_LOGGER, - clientId=self._client_id, message=_ConnectionStatusMessage.CHECKOUT_FAILED, + clientId=self._client_id, serverHost=self.address[0], serverPort=self.address[1], reason="Wait queue timeout elapsed without a connection becoming available", @@ -1693,5 +1468,6 @@ def __del__(self) -> None: # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. - for conn in self.conns: - conn.close_conn(None) + if _IS_SYNC: + for conn in self.conns: + conn.close_conn(None) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index ed48cc6cc8..c3643ba815 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -108,10 +108,10 @@ def close(self) -> None: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.STOP_SERVER, topologyId=self._topology_id, serverHost=self._description.address[0], serverPort=self._description.address[1], - message=_SDAMStatusMessage.STOP_SERVER, ) self._monitor.close() @@ -173,8 +173,8 @@ def run_operation( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, command=cmd, commandName=next(iter(cmd)), databaseName=dbn, @@ -234,8 +234,8 @@ def run_operation( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, durationMS=duration, failure=failure, commandName=next(iter(cmd)), @@ -278,8 +278,8 @@ def run_operation( if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, durationMS=duration, reply=res, commandName=next(iter(cmd)), diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index 040776713f..61b86fa18d 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -51,6 +51,7 @@ def __init__( srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, + topology_id: Optional[ObjectId] = None, ): """Represent MongoClient's configuration. @@ -78,8 +79,10 @@ def __init__( self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 self._server_monitoring_mode = server_monitoring_mode - - self._topology_id = ObjectId() + if topology_id is not None: + self._topology_id = topology_id + else: + self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. self._stack = "".join(traceback.format_stack()[:-2]) diff --git a/pymongo/srv_resolver.py b/pymongo/synchronous/srv_resolver.py similarity index 80% rename from pymongo/srv_resolver.py rename to pymongo/synchronous/srv_resolver.py index 5be6cb98db..0817c6dcd7 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from dns import resolver +_IS_SYNC = True + def _have_dnspython() -> bool: try: @@ -45,13 +47,23 @@ def maybe_decode(text: Union[str, bytes]) -> str: # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: - from dns import resolver + if _IS_SYNC: + from dns import resolver - if hasattr(resolver, "resolve"): - # dnspython >= 2 - return resolver.resolve(*args, **kwargs) - # dnspython 1.X - return resolver.query(*args, **kwargs) + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + else: + from dns import asyncresolver + + if hasattr(asyncresolver, "resolve"): + # dnspython >= 2 + return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + raise ConfigurationError( + "Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections." 
+ ) _INVALID_HOST_MSG = ( @@ -78,14 +90,13 @@ def __init__( raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) except ValueError: pass - try: - self.__plist = self.__fqdn.split(".")[1:] + split_fqdn = self.__fqdn.split(".") + self.__plist = split_fqdn[1:] if len(split_fqdn) > 2 else split_fqdn except Exception: raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None self.__slen = len(self.__plist) - if self.__slen < 2: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) + self.nparts = len(split_fqdn) def get_options(self) -> Optional[str]: from dns import resolver @@ -127,8 +138,13 @@ def _get_srv_response_and_hosts( # Validate hosts for node in nodes: + srv_host = node[0].lower() + if self.__fqdn == srv_host and self.nparts < 3: + raise ConfigurationError( + "Invalid SRV host: return address is identical to SRV hostname" + ) try: - nlist = node[0].lower().split(".")[1:][-self.__slen :] + nlist = srv_host.split(".")[1:][-self.__slen :] except Exception: raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None if self.__plist != nlist: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index b03269ae43..1e99adf726 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -36,6 +37,7 @@ OperationFailure, PyMongoError, ServerSelectionTimeoutError, + WaitQueueTimeoutError, WriteError, ) from pymongo.hello import Hello @@ -61,7 +63,7 @@ writable_server_selector, ) from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.monitor import MonitorBase, SrvMonitor from pymongo.synchronous.pool import Pool from pymongo.synchronous.server import Server from pymongo.topology_description import ( @@ -118,8 +120,8 @@ def __init__(self, topology_settings: TopologySettings): if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, - topologyId=self._topology_id, message=_SDAMStatusMessage.START_TOPOLOGY, + topologyId=self._topology_id, ) if self._publish_tp: @@ -150,10 +152,10 @@ def __init__(self, topology_settings: TopologySettings): if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=initial_td, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) for seed in topology_settings.seeds: @@ -163,10 +165,10 @@ def __init__(self, topology_settings: TopologySettings): if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.START_SERVER, topologyId=self._topology_id, serverHost=seed[0], serverPort=seed[1], - message=_SDAMStatusMessage.START_SERVER, ) # Store the seed list to help diagnose errors in _error_message(). 
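The srv_resolver hunks above relax the old "FQDN must have at least three parts" rule but compensate with a new guard: for short names, an SRV answer that merely points back at the query host is rejected. The guard, in isolation:

    def srv_target_ok(fqdn: str, target: str) -> bool:
        """Reject SRV answers that merely echo a short query hostname."""
        nparts = len(fqdn.split("."))
        if fqdn == target.lower() and nparts < 3:
            # e.g. querying "localhost" must not "resolve" to localhost itself.
            return False
        return True

    assert srv_target_ok("cluster0.example.com", "shard0.cluster0.example.com")
    assert not srv_target_ok("localhost", "localhost")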
@@ -207,6 +209,9 @@ def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + def open(self) -> None: """Start monitoring, or restart after a fork. @@ -232,9 +237,7 @@ def open(self) -> None: warnings.warn( # type: ignore[call-overload] # noqa: B028 "MongoClient opened before fork. May not be entirely fork-safe, " "proceed with caution. See PyMongo's documentation for details: " - "https://www.mongodb.com/docs/languages/" - "python/pymongo-driver/current/faq/" - "#is-pymongo-fork-safe-", + "https://dochub.mongodb.org/core/pymongo-fork-deadlock", **kwargs, ) with self._lock: @@ -283,6 +286,10 @@ def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + self.cleanup_monitors() + with self._lock: server_descriptions = self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -347,7 +354,7 @@ def _select_servers_loop( operationId=operation_id, topologyDescription=self.description, clientId=self.description._topology_settings._topology_id, - remainingTimeMS=int(end_time - time.monotonic()), + remainingTimeMS=int(1000 * (end_time - time.monotonic())), ) logged_waiting = True @@ -493,7 +500,6 @@ def _process_change( self._description = new_td self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None @@ -506,10 +512,10 @@ def _process_change( if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event: _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=td_old, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) # Shutdown SRV polling for unsupported cluster types. @@ -520,12 +526,8 @@ def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() - - # Clear the pool from a failed heartbeat. - if reset_pool: - server = self._servers.get(server_description.address) - if server: - server.pool.reset(interrupt_connections=interrupt_connections) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Wake anything waiting in select_servers(). self._condition.notify_all() @@ -549,6 +551,11 @@ def on_change( # that didn't include this server. if self._opened and self._description.has_server(server_description.address): self._process_change(server_description, reset_pool, interrupt_connections) + # Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close. + if reset_pool: + server = self._servers.get(server_description.address) + if server: + server.pool.reset(interrupt_connections=interrupt_connections) def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new seedlist on an opened topology. 
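The remainingTimeMS change above is a units fix: int(end_time - time.monotonic()) truncated a value measured in seconds and logged it under a millisecond-named field, so any sub-second wait logged as 0. The corrected conversion:

    import time

    def remaining_time_ms(end_time: float) -> int:
        # Convert the remaining server-selection budget from seconds to ms.
        return int(1000 * (end_time - time.monotonic()))

    end = time.monotonic() + 1.5
    assert 0 <= remaining_time_ms(end) <= 1500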
@@ -572,10 +579,10 @@ def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=td_old, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: @@ -693,6 +700,8 @@ def close(self) -> None: old_td = self._description for server in self._servers.values(): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -703,6 +712,8 @@ def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -732,13 +743,13 @@ def close(self) -> None: if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _SDAM_LOGGER, + message=_SDAMStatusMessage.TOPOLOGY_CHANGE, topologyId=self._topology_id, previousDescription=old_td, newDescription=self._description, - message=_SDAMStatusMessage.TOPOLOGY_CHANGE, ) _debug_log( - _SDAM_LOGGER, topologyId=self._topology_id, message=_SDAMStatusMessage.STOP_TOPOLOGY + _SDAM_LOGGER, message=_SDAMStatusMessage.STOP_TOPOLOGY, topologyId=self._topology_id ) if self._publish_server or self._publish_tp: @@ -877,6 +888,8 @@ def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None: # Clear the pool. server.reset(service_id) elif isinstance(error, ConnectionFailure): + if isinstance(error, WaitQueueTimeoutError): + return # "Client MUST replace the server's description with type Unknown # ... MUST NOT request an immediate check of the server." if not self._settings.load_balanced: @@ -942,6 +955,8 @@ def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1029,6 +1044,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py new file mode 100644 index 0000000000..52b59b8fe8 --- /dev/null +++ b/pymongo/synchronous/uri_parser.py @@ -0,0 +1,188 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
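cleanup_monitors above drains self._monitor_tasks and joins everything with return_exceptions=True, so one failing monitor cannot hide the rest. The drain-then-gather shape, as a self-contained asyncio sketch:

    import asyncio

    async def cleanup(tasks: list[asyncio.Task]) -> None:
        pending, tasks[:] = tasks[:], []  # Drain the shared list first.
        # return_exceptions=True collects failures instead of raising the first.
        results = await asyncio.gather(*pending, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception):
                pass  # A real implementation would log these.

    async def main() -> None:
        async def noop() -> None: ...
        await cleanup([asyncio.create_task(noop())])

    asyncio.run(main())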
+
+
+"""Tools to parse and validate a MongoDB URI."""
+from __future__ import annotations
+
+from typing import Any, Optional
+from urllib.parse import unquote_plus
+
+from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
+from pymongo.errors import ConfigurationError, InvalidURI
+from pymongo.synchronous.srv_resolver import _SrvResolver
+from pymongo.uri_parser_shared import (
+    _ALLOWED_TXT_OPTS,
+    DEFAULT_PORT,
+    SCHEME,
+    SCHEME_LEN,
+    SRV_SCHEME_LEN,
+    _check_options,
+    _validate_uri,
+    split_hosts,
+    split_options,
+)
+
+_IS_SYNC = True
+
+
+def parse_uri(
+    uri: str,
+    default_port: Optional[int] = DEFAULT_PORT,
+    validate: bool = True,
+    warn: bool = False,
+    normalize: bool = True,
+    connect_timeout: Optional[float] = None,
+    srv_service_name: Optional[str] = None,
+    srv_max_hosts: Optional[int] = None,
+) -> dict[str, Any]:
+    """Parse and validate a MongoDB URI.
+
+    Returns a dict of the form::
+
+        {
+            'nodelist': <list of (host, port) tuples>,
+            'username': <username> or None,
+            'password': <password> or None,
+            'database': <database name> or None,
+            'collection': <collection name> or None,
+            'options': <dict of MongoDB URI options>,
+            'fqdn': <fqdn of the MongoDB+SRV URI> or None
+        }
+
+    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
+    to build nodelist and options.
+
+    :param uri: The MongoDB URI to parse.
+    :param default_port: The port number to use when one wasn't specified
+        for a host in the URI.
+    :param validate: If ``True`` (the default), validate and
+        normalize all options. Default: ``True``.
+    :param warn: When validating, if ``True`` then warn
+        the user and then ignore any invalid options or values. If ``False``,
+        validation will error when options are unsupported or values are
+        invalid. Default: ``False``.
+    :param normalize: If ``True``, convert names of URI options
+        to their internally-used names. Default: ``True``.
+    :param connect_timeout: The maximum time in milliseconds to
+        wait for a response from the DNS server.
+    :param srv_service_name: A custom SRV service name
+
+    .. versionchanged:: 4.6
+        The delimiting slash (``/``) between hosts and connection options is now optional.
+        For example, "mongodb://example.com?tls=true" is now a valid URI.
+
+    .. versionchanged:: 4.0
+        To better follow RFC 3986, unquoted percent signs ("%") are no longer
+        supported.
+
+    .. versionchanged:: 3.9
+        Added the ``normalize`` parameter.
+
+    .. versionchanged:: 3.6
+        Added support for mongodb+srv:// URIs.
+
+    .. versionchanged:: 3.5
+        Return the original value of the ``readPreference`` MongoDB URI option
+        instead of the validated read preference mode.
+
+    .. versionchanged:: 3.1
+        ``warn`` added so invalid options can be ignored.
+ """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index f669fefd2e..29293b2314 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -33,7 +33,7 @@ from bson.min_key import MinKey from bson.objectid import ObjectId from pymongo import common -from pymongo.errors import ConfigurationError +from pymongo.errors import ConfigurationError, PyMongoError from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection @@ -563,7 +563,11 @@ def _update_rs_from_primary( if None not in new_election_tuple: if None not in max_election_tuple and new_election_tuple < max_election_tuple: # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() + sds[server_description.address] = server_description.to_unknown( + PyMongoError( + f"primary marked stale due to electionId/setVersion mismatch, {new_election_tuple} is stale compared to {max_election_tuple}" + ) + ) return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id max_election_id = server_description.election_id @@ -578,7 +582,11 @@ def _update_rs_from_primary( max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) if new_election_safe < max_election_safe: # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() + sds[server_description.address] = server_description.to_unknown( + PyMongoError( + f"primary marked stale due to electionId/setVersion mismatch, {new_election_tuple} is stale compared to {max_election_tuple}" + ) + ) return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id else: max_election_id = server_description.election_id @@ -591,7 +599,9 @@ def _update_rs_from_primary( and server.address != server_description.address ): # Reset old primary's type to Unknown. - sds[server.address] = server.to_unknown() + sds[server.address] = server.to_unknown( + PyMongoError("primary marked stale due to discovery of newer primary") + ) # There can be only one prior primary. break diff --git a/pymongo/typings.py b/pymongo/typings.py index 68962eb540..ce6f369d1f 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 7018dad7d8..fe253b9bbf 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,627 +13,32 @@ # permissions and limitations under the License. -"""Tools to parse and validate a MongoDB URI. - -.. 
seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs. -""" +"""Re-import of synchronous URI Parser API for compatibility.""" from __future__ import annotations -import re import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.client_options import _parse_ssl_options -from pymongo.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.typings import _Address - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. - - :param userinfo: A string of the form : - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not valid username.") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in braces (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. 
- """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit(): - # Special case check for mistakes like "mongodb://localhost:27017 ". - if all(c.isspace() or c.isdigit() for c in port): - for c in port: - if c.isspace(): - raise ValueError(f"Port contains whitespace character: {c!r}") - - # A non-digit port indicates that the URI is invalid, likely because the password - # or username were not escaped. - raise ValueError( - "Port contains non-digit characters. Hint: username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - if int(port) > 65535 or int(port) <= 0: - raise ValueError("Port must be an integer between 0 and 65535") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string. - """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." 
- raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." - raise InvalidURI(err_msg % (opt,)) - - if "ssl" in options and "tls" in options: - - def truth_value(val: Any) -> Any: - if val in ("true", "false"): - return val == "true" - if isinstance(val, bool): - return val - return val - - if truth_value(options.get("ssl")) != truth_value(options.get("tls")): - err_msg = "Can not specify conflicting values for URI options %s and %s." - raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) - - return options - - -def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Issue appropriate warnings when deprecated options are present in the - options dictionary. Removes deprecated option key, value pairs if the - options dictionary is found to also have the renamed option. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - for optname in list(options): - if optname in URI_OPTIONS_DEPRECATION_MAP: - mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] - if mode == "renamed": - newoptname = message - if newoptname in options: - warn_msg = "Deprecated option '%s' ignored in favor of '%s'." - warnings.warn( - warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), - DeprecationWarning, - stacklevel=2, - ) - options.pop(optname) - continue - warn_msg = "Option '%s' is deprecated, use '%s' instead." - warnings.warn( - warn_msg % (options.cased_key(optname), newoptname), - DeprecationWarning, - stacklevel=2, - ) - elif mode == "removed": - warn_msg = "Option '%s' is deprecated. %s." - warnings.warn( - warn_msg % (options.cased_key(optname), message), - DeprecationWarning, - stacklevel=2, - ) - - return options - - -def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Normalizes option names in the options dictionary by converting them to - their internally-used names. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Expand the tlsInsecure option. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - # Implicit options are logically the same as tlsInsecure. - options[opt] = tlsinsecure - - for optname in list(options): - intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) - if intname is not None: - options[intname] = options.pop(optname) - - return options - - -def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: - """Validates and normalizes options passed in a MongoDB URI. - - Returns a new dictionary of validated and normalized options. If warn is - False then errors will be thrown for invalid options, otherwise they will - be ignored and a warning will be issued. - - :param opts: A dict of MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise invalid options will - cause errors. 
- """ - return get_validated_options(opts, warn) - - -def split_options( - opts: str, validate: bool = True, warn: bool = False, normalize: bool = True -) -> MutableMapping[str, Any]: - """Takes the options portion of a MongoDB URI, validates each option - and returns the options in a dictionary. - - :param opt: A string representing MongoDB URI options. - :param validate: If ``True`` (the default), validate and normalize all - options. - :param warn: If ``False`` (the default), suppress all warnings raised - during validation of options. - :param normalize: If ``True`` (the default), renames all options to their - internally-used names. - """ - and_idx = opts.find("&") - semi_idx = opts.find(";") - try: - if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Can not mix '&' and ';' for option separators.") - elif and_idx >= 0: - options = _parse_options(opts, "&") - elif semi_idx >= 0: - options = _parse_options(opts, ";") - elif opts.find("=") != -1: - options = _parse_options(opts, None) - else: - raise ValueError - except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") from None - - options = _handle_security_options(options) - - options = _handle_option_deprecations(options) - - if normalize: - options = _normalize_options(options) - - if validate: - options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) - if options.get("authsource") == "": - raise InvalidURI("the authSource database cannot be an empty string") - - return options - - -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: - """Takes a string of the form host1[:port],host2[:port]... and - splits it into (host, port) tuples. If [:port] isn't present the - default_port is used. - - Returns a set of 2-tuples containing the host name (or IP) followed by - port number. - - :param hosts: A string of the form host1[:port],host2[:port],... - :param default_port: The port number to use when one wasn't specified - for a host. - """ - nodes = [] - for entity in hosts.split(","): - if not entity: - raise ConfigurationError("Empty host (or extra comma in host list).") - port = default_port - # Unix socket entities don't have ports - if entity.endswith(".sock"): - port = None - nodes.append(parse_host(entity, port)) - return nodes - - -# Prohibited characters in database name. DB names also can't have ".", but for -# backward-compat we allow "db.collection" in URI. -_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") - -_ALLOWED_TXT_OPTS = frozenset( - ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] -) - - -def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: - # Ensure directConnection was not True if there are multiple seeds. 
- if len(nodes) > 1 and options.get("directconnection"): - raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") - - if options.get("loadbalanced"): - if len(nodes) > 1: - raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") - if options.get("directconnection"): - raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") - if options.get("replicaset"): - raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") - - - def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, - ) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': <list of (host, port) tuples>, - 'username': <username> or None, - 'password': <password> or None, - 'database': <database name> or None, - 'collection': <collection name> or None, - 'options': <dict of MongoDB URI options>, - 'fqdn': <fqdn of the MongoDB+SRV URI> or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. - :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "."
in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. - connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = _CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is 
required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts +from pymongo.errors import InvalidURI +from pymongo.synchronous.uri_parser import * # noqa: F403 +from pymongo.synchronous.uri_parser import __doc__ as original_doc +from pymongo.uri_parser_shared import * # noqa: F403 + +__doc__ = original_doc +__all__ = [ # noqa: F405 + "parse_userinfo", + "parse_ipv6_literal_host", + "parse_host", + "validate_options", + "split_options", + "split_hosts", + "parse_uri", +] if __name__ == "__main__": import pprint try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203 except InvalidURI as exc: print(exc) # noqa: T201 sys.exit(0) diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py new file mode 100644 index 0000000000..0cef176bf1 --- /dev/null +++ b/pymongo/uri_parser_shared.py @@ -0,0 +1,552 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI. + +.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs. +""" +from __future__ import annotations + +import re +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sized, + Union, + cast, +) +from urllib.parse import unquote_plus + +from pymongo.asynchronous.srv_resolver import _have_dnspython +from pymongo.client_options import _parse_ssl_options +from pymongo.common import ( + INTERNAL_URI_OPTION_NAME_MAP, + URI_OPTIONS_DEPRECATION_MAP, + _CaseInsensitiveDictionary, + get_validated_options, +) +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.typings import _Address + +if TYPE_CHECKING: + from pymongo.pyopenssl_context import SSLContext + +SCHEME = "mongodb://" +SCHEME_LEN = len(SCHEME) +SRV_SCHEME = "mongodb+srv://" +SRV_SCHEME_LEN = len(SRV_SCHEME) +DEFAULT_PORT = 27017 + + +def _unquoted_percent(s: str) -> bool: + """Check for unescaped percent signs. + + :param s: A string. `s` can have things like '%25', '%2525', + and '%E2%85%A8' but cannot have unquoted percent like '%foo'. + """ + for i in range(len(s)): + if s[i] == "%": + sub = s[i : i + 3] + # If unquoting yields the same string this means there was an + # unquoted %. + if unquote_plus(sub) == sub: + return True + return False + + +def parse_userinfo(userinfo: str) -> tuple[str, str]: + """Validates the format of user information in a MongoDB URI. + Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", + "]", "@") as per RFC 3986 must be escaped. 
+ + Returns a 2-tuple containing the unescaped username followed + by the unescaped password. + + :param userinfo: A string of the form <username>:<password> + """ + if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): + raise InvalidURI( + "Username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + + user, _, passwd = userinfo.partition(":") + # No password is expected with GSSAPI authentication. + if not user: + raise InvalidURI("The empty string is not a valid username") + + return unquote_plus(user), unquote_plus(passwd) + + + def parse_ipv6_literal_host( + entity: str, default_port: Optional[int] + ) -> tuple[str, Optional[Union[str, int]]]: + """Validates an IPv6 literal host:port string. + + Returns a 2-tuple of IPv6 literal followed by port where + port is default_port if it wasn't specified in entity. + + :param entity: A string that represents an IPv6 literal enclosed + in braces (e.g. '[::1]' or '[::1]:27017'). + :param default_port: The port number to use when one wasn't + specified in entity. + """ + if entity.find("]") == -1: + raise ValueError( + "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." + ) + i = entity.find("]:") + if i == -1: + return entity[1:-1], default_port + return entity[1:i], entity[i + 2 :] + + + def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: + """Validates a host string + + Returns a 2-tuple of host followed by port where port is default_port + if it wasn't specified in the string. + + :param entity: A host or host:port string where host could be a + hostname or IP address. + :param default_port: The port number to use when one wasn't + specified in entity. + """ + host = entity + port: Optional[Union[str, int]] = default_port + if entity[0] == "[": + host, port = parse_ipv6_literal_host(entity, default_port) + elif entity.endswith(".sock"): + return entity, default_port + elif entity.find(":") != -1: + if entity.count(":") > 1: + raise ValueError( + "Reserved characters such as ':' must be " + "escaped according to RFC 2396. An IPv6 " + "address literal must be enclosed in '[' " + "and ']' according to RFC 2732." + ) + host, port = host.split(":", 1) + if isinstance(port, str): + if not port.isdigit(): + # Special case check for mistakes like "mongodb://localhost:27017 ". + if all(c.isspace() or c.isdigit() for c in port): + for c in port: + if c.isspace(): + raise ValueError(f"Port contains whitespace character: {c!r}") + + # A non-digit port indicates that the URI is invalid, likely because the password + # or username were not escaped. + raise ValueError( + "Port contains non-digit characters. Hint: username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + if int(port) > 65535 or int(port) <= 0: + raise ValueError("Port must be an integer between 0 and 65535") + port = int(port) + + # Normalize hostname to lowercase, since DNS is case-insensitive: + # https://tools.ietf.org/html/rfc4343 + # This prevents useless rediscovery if "foo.com" is in the seed list but + # "FOO.com" is in the hello response. + return host.lower(), port + + + # Options whose values are implicitly determined by tlsInsecure. + _IMPLICIT_TLSINSECURE_OPTS = { + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", + } + + + def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: + """Helper method for split_options which creates the options dict.
+ Also handles the creation of a list for the URI tag_sets/ + readpreferencetags portion, and the use of a unicode options string. + """ + options = _CaseInsensitiveDictionary() + for uriopt in opts.split(delim): + key, value = uriopt.split("=") + if key.lower() == "readpreferencetags": + options.setdefault(key, []).append(value) + else: + if key in options: + warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) + if key.lower() == "authmechanismproperties": + val = value + else: + val = unquote_plus(value) + options[key] = val + + return options + + + def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Raise appropriate errors when conflicting TLS options are present in + the options dictionary. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Implicitly defined options must not be explicitly specified. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + if opt in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) + ) + + # Handle co-occurrence of OCSP & tlsAllowInvalidCertificates options. + tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") + if tlsallowinvalidcerts is not None: + if "tlsdisableocspendpointcheck" in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg + % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) + ) + if tlsallowinvalidcerts is True: + options["tlsdisableocspendpointcheck"] = True + + # Handle co-occurrence of CRL and OCSP-related options. + tlscrlfile = options.get("tlscrlfile") + if tlscrlfile is not None: + for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): + if options.get(opt) is True: + err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." + raise InvalidURI(err_msg % (opt,)) + + if "ssl" in options and "tls" in options: + + def truth_value(val: Any) -> Any: + if val in ("true", "false"): + return val == "true" + if isinstance(val, bool): + return val + return val + + if truth_value(options.get("ssl")) != truth_value(options.get("tls")): + err_msg = "Can not specify conflicting values for URI options %s and %s." + raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) + + return options + + + def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Issue appropriate warnings when deprecated options are present in the + options dictionary. Removes deprecated option key, value pairs if the + options dictionary is found to also have the renamed option. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + for optname in list(options): + if optname in URI_OPTIONS_DEPRECATION_MAP: + mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] + if mode == "renamed": + newoptname = message + if newoptname in options: + warn_msg = "Deprecated option '%s' ignored in favor of '%s'." + warnings.warn( + warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), + DeprecationWarning, + stacklevel=2, + ) + options.pop(optname) + continue + warn_msg = "Option '%s' is deprecated, use '%s' instead."
+ warnings.warn( + warn_msg % (options.cased_key(optname), newoptname), + DeprecationWarning, + stacklevel=2, + ) + elif mode == "removed": + warn_msg = "Option '%s' is deprecated. %s." + warnings.warn( + warn_msg % (options.cased_key(optname), message), + DeprecationWarning, + stacklevel=2, + ) + + return options + + +def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Normalizes option names in the options dictionary by converting them to + their internally-used names. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Expand the tlsInsecure option. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + # Implicit options are logically the same as tlsInsecure. + options[opt] = tlsinsecure + + for optname in list(options): + intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) + if intname is not None: + options[intname] = options.pop(optname) + + return options + + +def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: + """Validates and normalizes options passed in a MongoDB URI. + + Returns a new dictionary of validated and normalized options. If warn is + False then errors will be thrown for invalid options, otherwise they will + be ignored and a warning will be issued. + + :param opts: A dict of MongoDB URI options. + :param warn: If ``True`` then warnings will be logged and + invalid options will be ignored. Otherwise invalid options will + cause errors. + """ + return get_validated_options(opts, warn) + + +def split_options( + opts: str, validate: bool = True, warn: bool = False, normalize: bool = True +) -> MutableMapping[str, Any]: + """Takes the options portion of a MongoDB URI, validates each option + and returns the options in a dictionary. + + :param opt: A string representing MongoDB URI options. + :param validate: If ``True`` (the default), validate and normalize all + options. + :param warn: If ``False`` (the default), suppress all warnings raised + during validation of options. + :param normalize: If ``True`` (the default), renames all options to their + internally-used names. + """ + and_idx = opts.find("&") + semi_idx = opts.find(";") + try: + if and_idx >= 0 and semi_idx >= 0: + raise InvalidURI("Can not mix '&' and ';' for option separators") + elif and_idx >= 0: + options = _parse_options(opts, "&") + elif semi_idx >= 0: + options = _parse_options(opts, ";") + elif opts.find("=") != -1: + options = _parse_options(opts, None) + else: + raise ValueError + except ValueError: + raise InvalidURI("MongoDB URI options are key=value pairs") from None + + options = _handle_security_options(options) + + options = _handle_option_deprecations(options) + + if normalize: + options = _normalize_options(options) + + if validate: + options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) + if options.get("authsource") == "": + raise InvalidURI("the authSource database cannot be an empty string") + + return options + + +def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: + """Takes a string of the form host1[:port],host2[:port]... and + splits it into (host, port) tuples. If [:port] isn't present the + default_port is used. + + Returns a set of 2-tuples containing the host name (or IP) followed by + port number. + + :param hosts: A string of the form host1[:port],host2[:port],... 
+ :param default_port: The port number to use when one wasn't specified + for a host. + """ + nodes = [] + for entity in hosts.split(","): + if not entity: + raise ConfigurationError("Empty host (or extra comma in host list)") + port = default_port + # Unix socket entities don't have ports + if entity.endswith(".sock"): + port = None + nodes.append(parse_host(entity, port)) + return nodes + + +# Prohibited characters in database name. DB names also can't have ".", but for +# backward-compat we allow "db.collection" in URI. +_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") + +_ALLOWED_TXT_OPTS = frozenset( + ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] +) + + +def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: + # Ensure directConnection was not True if there are multiple seeds. + if len(nodes) > 1 and options.get("directconnection"): + raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") + + if options.get("loadbalanced"): + if len(nodes) > 1: + raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") + if options.get("directconnection"): + raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") + if options.get("replicaset"): + raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") + + +def _parse_kms_tls_options( + kms_tls_options: Optional[Mapping[str, Any]], + is_sync: bool, +) -> dict[str, SSLContext]: + """Parse KMS TLS connection options.""" + if not kms_tls_options: + return {} + if not isinstance(kms_tls_options, dict): + raise TypeError("kms_tls_options must be a dict") + contexts = {} + for provider, options in kms_tls_options.items(): + if not isinstance(options, dict): + raise TypeError(f'kms_tls_options["{provider}"] must be a dict') + options.setdefault("tls", True) + opts = _CaseInsensitiveDictionary(options) + opts = _handle_security_options(opts) + opts = _normalize_options(opts) + opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) + ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts, is_sync) + if ssl_context is None: + raise ConfigurationError("TLS is required for KMS providers") + if allow_invalid_hostnames: + raise ConfigurationError("Insecure TLS options prohibited") + + for n in [ + "tlsInsecure", + "tlsAllowInvalidCertificates", + "tlsAllowInvalidHostnames", + "tlsDisableCertificateRevocationCheck", + ]: + if n in opts: + raise ConfigurationError(f"Insecure TLS options prohibited: {n}") + contexts[provider] = ssl_context + return contexts + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. 
" + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py index 67c9549897..ff31c6730d 100644 --- a/pymongo/write_concern.py +++ b/pymongo/write_concern.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -74,7 +74,7 @@ def __init__( if wtimeout is not None: if not isinstance(wtimeout, int): - raise TypeError("wtimeout must be an integer") + raise TypeError(f"wtimeout must be an integer, not {type(wtimeout)}") if wtimeout < 0: raise ValueError("wtimeout cannot be less than 0") self.__document["wtimeout"] = wtimeout @@ -98,7 +98,7 @@ def __init__( raise ValueError("w cannot be less than 0") self.__acknowledged = w > 0 elif not isinstance(w, str): - raise TypeError("w must be an integer or string") + raise TypeError(f"w must be an integer or string, not {type(w)}") self.__document["w"] = w self.__server_default = not self.__document diff --git a/pyproject.toml b/pyproject.toml index 69249ee4c6..4da75b4c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ Tracker = "https://jira.mongodb.org/projects/PYTHON/issues" dev = [ "pre-commit>=4.0" ] +pip = ["pip"] gevent = ["gevent"] eventlet = ["eventlet"] coverage = [ @@ -115,21 +116,24 @@ filterwarnings = [ "module:unclosed =1.1.0,<2.0.0 -pymongocrypt>=1.12.0,<2.0.0 +pymongocrypt>=1.13.0,<2.0.0 certifi;os.name=='nt' or sys_platform=='darwin' diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d5..ae5d60a384 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,30 +31,8 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri try: import ipaddress @@ -63,7 +42,6 @@ HAVE_IPADDRESS = False from contextlib import contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +56,32 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = True @@ -589,7 +593,7 @@ def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False @@ -674,7 +678,6 @@ def is_topology_type(self, topologies): 
"single", "replicaset", "sharded", - "sharded-replicaset", "load-balanced", } if unknown: @@ -689,16 +692,6 @@ def is_topology_type(self, topologies): return True if "sharded" in topologies and self.is_mongos: return True - if "sharded-replicaset" in topologies and self.is_mongos: - shards = client_context.client.config.shards.find().to_list() - for shard in shards: - # For a 3-member RS-backed sharded cluster, shard['host'] - # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3' - # Otherwise it will be 'ip1:port1' - host_spec = shard["host"] - if not len(host_spec.split("/")) > 1: - return False - return True return False def require_cluster_type(self, topologies=None): @@ -830,6 +823,14 @@ def require_sync(self, func): lambda: _IS_SYNC, "This test only works with the synchronous API", func=func ) + def require_async(self, func): + """Run a test only if using the asynchronous API.""" # unasync: off + return self._require( + lambda: not _IS_SYNC, + "This test only works with the asynchronous API", # unasync: off + func=func, + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) @@ -863,35 +864,88 @@ def max_message_size_bytes(self): # Reusable client context client_context = ClientContext() +# Global event loop for async tests. +LOOP = None -def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif client_context.client is not None: - client_context.client.close() - client_context.client = None - client_context._init_client() + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP class PyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by TestCase. 
+ def setUp(self): + pass + + def tearDown(self): + pass + + def addCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.setUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.tearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) def assertEqualReply(self, expected, actual, msg=None): self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg) + @staticmethod + def configure_fail_point(client, command_args, off=False): + cmd = {"configureFailPoint": "failCommand"} + cmd.update(command_args) + if off: + cmd["mode"] = "off" + cmd.pop("data", None) + client.admin.command(cmd) + @contextmanager def fail_point(self, command_args): - cmd_on = SON([("configureFailPoint", "failCommand")]) - cmd_on.update(command_args) - client_context.client.admin.command(cmd_on) + self.configure_fail_point(client_context.client, command_args) try: yield finally: - client_context.client.admin.command( - "configureFailPoint", cmd_on["configureFailPoint"], mode="off" - ) + self.configure_fail_point(client_context.client, command_args, off=True) @contextmanager def fork( @@ -1136,8 +1190,6 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: - reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1186,6 +1238,9 @@ def tearDown(self) -> None: def setup(): + if not _IS_SYNC: + # Set up the event loop. 
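+ # get_loop() caches the loop in the module-level LOOP on first call, so this primes the shared loop before any test runs.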
+ get_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e2824742..b772da3126 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,30 +31,8 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) -from pymongo.uri_parser import parse_uri +from pymongo.asynchronous.uri_parser import parse_uri try: import ipaddress @@ -63,7 +42,6 @@ HAVE_IPADDRESS = False from contextlib import asynccontextmanager, contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +56,32 @@ from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = False @@ -588,10 +592,10 @@ async def check(): @property async def supports_secondary_read_pref(self): - if self.has_secondaries: + if await self.has_secondaries: return True if self.is_mongos: - shard = await self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False @@ -676,7 +680,6 @@ async def is_topology_type(self, topologies): "single", "replicaset", "sharded", - "sharded-replicaset", "load-balanced", } if unknown: @@ -691,16 +694,6 @@ async def is_topology_type(self, topologies): return True if "sharded" in topologies and self.is_mongos: return True - if "sharded-replicaset" in topologies and self.is_mongos: - shards = await async_client_context.client.config.shards.find().to_list() - for shard in shards: - # For a 3-member RS-backed sharded cluster, shard['host'] - # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3' - # Otherwise it will be 'ip1:port1' - host_spec = shard["host"] - if not len(host_spec.split("/")) > 1: - return False - return True return False def require_cluster_type(self, topologies=None): @@ -832,6 +825,14 @@ def require_sync(self, func): lambda: _IS_SYNC, "This test only works with the synchronous API", func=func ) + def require_async(self, func): + """Run a test only if using the asynchronous API.""" # unasync: off + return self._require( + lambda: not _IS_SYNC, + "This test only works with the asynchronous API", # unasync: off + func=func, + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) @@ -865,35 +866,88 @@ async def 
max_message_size_bytes(self): # Reusable client context async_client_context = AsyncClientContext() +# Global event loop for async tests. +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + +class AsyncPyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by IsolatedAsyncioTestCase. + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass -async def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif async_client_context.client is not None: - await async_client_context.client.close() - async_client_context.client = None - await async_client_context._init_client() + def addAsyncCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + def _callSetUp(self): + self.setUp() + self._callAsync(self.asyncSetUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) def assertEqualReply(self, expected, actual, msg=None): self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg) + @staticmethod + async def configure_fail_point(client, command_args, off=False): + cmd = {"configureFailPoint": "failCommand"} + cmd.update(command_args) + if off: + cmd["mode"] = "off" + cmd.pop("data", None) + await client.admin.command(cmd) + @asynccontextmanager async def fail_point(self, command_args): - cmd_on = SON([("configureFailPoint", "failCommand")]) - cmd_on.update(command_args) - await async_client_context.client.admin.command(cmd_on) + await self.configure_fail_point(async_client_context.client, command_args) try: yield finally: - await async_client_context.client.admin.command( - "configureFailPoint", cmd_on["configureFailPoint"], mode="off" - ) + await self.configure_fail_point(async_client_context.client, command_args, off=True) @contextmanager def fork( @@ -970,7 +1024,7 @@ async def _unmanaged_async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. 
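+ # (parse_uri is awaited below: the async variant resolves mongodb+srv:// URIs with non-blocking DNS lookups.)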
- res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] @@ -1001,7 +1055,7 @@ async def _async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. - res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] @@ -1124,15 +1178,15 @@ def unmanaged_simple_client( async def disable_replication(self, client): """Disable replication on all secondaries.""" - for h, p in client.secondaries: + for h, p in await client.secondaries: secondary = await self.async_single_client(h, p) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") async def enable_replication(self, client): """Enable replication on all secondaries.""" - for h, p in client.secondaries: + for h, p in await client.secondaries: secondary = await self.async_single_client(h, p) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") class AsyncUnitTest(AsyncPyMongoTestCase): @@ -1154,8 +1208,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: - await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1204,6 +1256,9 @@ async def asyncTearDown(self) -> None: async def async_setup(): + if not _IS_SYNC: + # Set up the event loop. 
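+ # As in the sync setup, priming get_loop() pins every subsequent async test to this one loop.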
+ get_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index b5fc5d8ac4..7b021e8b44 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -37,14 +40,14 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl @@ -78,7 +81,7 @@ COMPRESSORS = os.environ.get("COMPRESSORS") MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") -TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOAD_BALANCER")) TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") @@ -369,3 +372,53 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, **kwargs): + if _IS_SYNC: + super().__init__(**kwargs) + self.name = kwargs.get("name", "ConcurrentRunner") + self.stopped = False + self.task = None + self.target = kwargs.get("target", None) + self.args = kwargs.get("args", []) + + if not _IS_SYNC: + + async def start(self): + self.task = create_task(self.run(), name=self.name) + + async def join(self, timeout: Optional[float] = None): # type: ignore[override] + if self.task is not None: + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + try: + await self.target(*self.args) + finally: + self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.exc = None + + async def run(self): + try: + await super().run() + except BaseException as exc: + self.exc = exc + raise diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..40beb3c0dc 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -66,7 +66,7 @@ def __init__(self, server_description, topology, pool, topology_settings): def cancel_check(self): pass - def join(self): + async def join(self): pass def open(self): @@ -75,7 +75,7 @@ def open(self): def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py new file mode 100644 index 0000000000..f450ea23cc --- /dev/null +++ 
b/test/asynchronous/test_async_cancellation.py @@ -0,0 +1,129 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that async cancellation performed by users clean up resources correctly.""" +from __future__ import annotations + +import asyncio +import sys +from test.asynchronous.utils import async_get_pool +from test.utils_shared import delay, one + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected + + +class TestAsyncCancellation(AsyncIntegrationTest): + async def test_async_cancellation_closes_connection(self): + pool = await async_get_pool(self.client) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + conn = one(pool.conns) + + async def task(): + await self.client.db.test.find_one({"$where": delay(0.2)}) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(conn.closed) + + @async_client_context.require_transactions + async def test_async_cancellation_aborts_transaction(self): + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + session = self.client.start_session() + + async def callback(session): + await self.client.db.test.find_one({"$where": delay(0.2)}, session=session) + + async def task(): + await session.with_transaction(callback) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertFalse(session.in_transaction) + + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_cursor(self): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + cursor = self.client.db.test.find({}, batch_size=1) + await cursor.next() + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await cursor.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(cursor._killed) + + @async_client_context.require_change_streams + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_change_stream(self): + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + change_stream = await self.client.db.test.watch(batch_size=2) + event = asyncio.Event() + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 
200}, + } + + async def task(): + async with self.fail_point(fail_command): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + event.set() + await change_stream.next() + + task = asyncio.create_task(task()) + + await event.wait() + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(change_stream._closed) diff --git a/test/asynchronous/test_async_loop_safety.py b/test/asynchronous/test_async_loop_safety.py new file mode 100644 index 0000000000..7516cb8eeb --- /dev/null +++ b/test/asynchronous/test_async_loop_safety.py @@ -0,0 +1,36 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that the asynchronous API detects event loop changes and fails correctly.""" +from __future__ import annotations + +import asyncio +import unittest + +from pymongo import AsyncMongoClient + + +class TestClientLoopSafety(unittest.TestCase): + def test_client_errors_on_different_loop(self): + client = AsyncMongoClient() + loop1 = asyncio.new_event_loop() + loop1.run_until_complete(client.aconnect()) + loop2 = asyncio.new_event_loop() + with self.assertRaisesRegex( + RuntimeError, "Cannot use AsyncMongoClient in different event loop" + ): + loop2.run_until_complete(client.aconnect()) + loop1.run_until_complete(client.close()) + loop1.close() + loop2.close() diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 7172152d69..904674db16 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -30,7 +30,7 @@ async_client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index e9e43d5759..0a68658680 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -22,6 +22,8 @@ import warnings from test.asynchronous import AsyncPyMongoTestCase +import pytest + sys.path[0:0] = [""] from test import unittest @@ -30,6 +32,8 @@ from pymongo import AsyncMongoClient from pymongo.asynchronous.auth_oidc import OIDCCallback +pytestmark = pytest.mark.auth + _IS_SYNC = False _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 7191a412c1..65ed6e236a 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions @@ -301,7 +301,7 @@ async def test_numerous_inserts(self): async def test_bulk_max_message_size(self): await self.coll.delete_many({}) - 
self.addCleanup(self.coll.delete_many, {}) + self.addAsyncCleanup(self.coll.delete_many, {}) _16_MB = 16 * 1000 * 1000 # Generate a list of documents such that the first batched OP_MSG is # as close as possible to the 48MB limit. @@ -505,7 +505,7 @@ async def test_single_ordered_batch(self): async def test_single_error_ordered_batch(self): await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) requests: list = [ InsertOne({"b": 1, "a": 1}), UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), @@ -547,7 +547,7 @@ async def test_single_error_ordered_batch(self): async def test_multiple_error_ordered_batch(self): await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) requests: list = [ InsertOne({"b": 1, "a": 1}), UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), @@ -616,7 +616,7 @@ async def test_single_unordered_batch(self): async def test_single_error_unordered_batch(self): await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) requests: list = [ InsertOne({"b": 1, "a": 1}), UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), @@ -659,7 +659,7 @@ async def test_single_error_unordered_batch(self): async def test_multiple_error_unordered_batch(self): await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) requests: list = [ InsertOne({"b": 1, "a": 1}), UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True), @@ -961,7 +961,6 @@ async def cause_wtimeout(self, requests, ordered): @async_client_context.require_replica_set @async_client_context.require_secondaries_count(1) async def test_write_concern_failure_ordered(self): - self.skipTest("Skipping until PYTHON-4865 is resolved.") details = None # Ensure we don't raise on wnote. @@ -1003,7 +1002,7 @@ async def test_write_concern_failure_ordered(self): await self.coll.delete_many({}) await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) # Fail due to write concern support as well # as duplicate key error on ordered batch. @@ -1078,7 +1077,7 @@ async def test_write_concern_failure_unordered(self): await self.coll.delete_many({}) await self.coll.create_index("a", unique=True) - self.addCleanup(self.coll.drop_index, [("a", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("a", 1)]) # Fail due to write concern support as well # as duplicate key error on unordered batch. 
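Context for the addCleanup to addAsyncCleanup swaps in the file above: unittest's addCleanup invokes its target synchronously, so handing it a coroutine function merely creates a never-awaited coroutine and the teardown silently does nothing. A minimal, stdlib-only sketch of the pattern these tests rely on (class and attribute names are illustrative; no PyMongo involved):

import unittest


class CleanupSketch(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.closed = False

    async def _close_resource(self):
        # Coroutine teardown: it only runs to completion if something awaits it.
        self.closed = True

    async def test_cleanup_is_awaited(self):
        # addAsyncCleanup awaits the coroutine on the test's event loop after
        # the body finishes; addCleanup(self._close_resource) would only
        # create a coroutine object and emit a "never awaited" warning.
        self.addAsyncCleanup(self._close_resource)


if __name__ == "__main__":
    unittest.main()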
diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 08da00cc1e..0260cb7a82 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, @@ -410,7 +410,14 @@ async def test_change_operations(self): expected_update_description = {"updatedFields": {"new": 1}, "removedFields": ["foo"]} if async_client_context.version.at_least(4, 5, 0): expected_update_description["truncatedArrays"] = [] - self.assertEqual(expected_update_description, change["updateDescription"]) + self.assertEqual( + expected_update_description, + { + k: v + for k, v in change["updateDescription"].items() + if k in expected_update_description + }, + ) # Replace. await self.watched_collection().replace_one({"new": 1}, {"foo": "bar"}) change = await change_stream.next() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 744a170be2..3a93613067 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -60,14 +60,16 @@ unittest, ) from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.asynchronous.utils import ( + async_get_pool, + async_wait_until, + asyncAssertRaisesExactly, +) from test.test_binary import BinaryData -from test.utils import ( +from test.utils_shared import ( NTHREADS, CMAPListener, FunctionCallRecorder, - async_get_pool, - async_wait_until, - asyncAssertRaisesExactly, delay, gevent_monkey_patched, is_greenthread_patched, @@ -111,6 +113,7 @@ NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, + WaitQueueTimeoutError, WriteConcernError, ) from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent @@ -509,13 +512,13 @@ async def test_uri_option_precedence(self): async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. 
- from pymongo.srv_resolver import _resolve + from pymongo.asynchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.asynchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.asynchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) @@ -604,7 +607,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -626,7 +629,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ @@ -743,7 +746,7 @@ async def test_min_pool_size(self): # Assert that if a socket is closed, a new one takes its place async with server._pool.checkout() as conn: - conn.close_conn(None) + await conn.close_conn(None) await async_wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", @@ -846,6 +849,58 @@ async def test_init_disconnected_with_auth(self): with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() + @async_client_context.require_replica_set + @async_client_context.require_no_load_balancer + @async_client_context.require_tls + async def test_init_disconnected_with_srv(self): + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # nodes returns an empty set if not connected + self.assertEqual(c.nodes, frozenset()) + # topology_description returns the initial seed description if not connected + topology_description = c.topology_description + self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown) + self.assertEqual( + { + ("test1.test.build.10gen.cc", None): ServerDescription( + ("test1.test.build.10gen.cc", None) + ) + }, + topology_description.server_descriptions(), + ) + + # address causes client to block until connected + self.assertIsNotNone(await c.address) + # Initial seed topology and connected topology have the same ID + self.assertEqual( + c._topology._topology_id, topology_description._topology_settings._topology_id + ) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # primary causes client to block until connected + await c.primary + self.assertIsNotNone(c._topology) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # secondaries causes client to block until connected + await c.secondaries + self.assertIsNotNone(c._topology) + await c.close() + + c = await self.async_rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # arbiters causes client to block until connected + await c.arbiters + self.assertIsNotNone(c._topology) + 
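The test_init_disconnected_with_srv additions above exercise the async client's lazy-connect contract: with connect=False, constructing the client performs no I/O (for mongodb+srv:// URIs not even DNS resolution), and awaiting address, primary, secondaries, or arbiters is what forces discovery. A short usage sketch of that contract, assuming PyMongo's async API and a reachable deployment at a placeholder localhost URI:

import asyncio

from pymongo import AsyncMongoClient


async def main():
    # No sockets are opened here; the topology holds only the seed list.
    client = AsyncMongoClient("mongodb://localhost:27017", connect=False)
    print(client.nodes)  # frozenset(): nothing discovered yet

    # Awaiting .address blocks until server selection succeeds or times out.
    print(await client.address)
    print(client.nodes)  # now reflects the discovered topology
    await client.close()


asyncio.run(main())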
async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await self.async_rs_or_single_client(seed, connect=False) @@ -930,6 +985,15 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) + async def test_repr_srv_host(self): + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + await client.aconnect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + await client.close() + async def test_getters(self): await async_wait_until( lambda: async_client_context.nodes == self.client.nodes, "find all nodes" @@ -1258,7 +1322,6 @@ async def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") await no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1311,13 +1374,21 @@ async def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): - client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) - self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) + listener = CMAPListener() + client = await self.async_rs_or_single_client( + waitQueueTimeoutMS=10, maxPoolSize=1, event_listeners=[listener] + ) + pool = await async_get_pool(client) + self.assertEqual(pool.opts.wait_queue_timeout, 0.01) + async with pool.checkout(): + with self.assertRaises(WaitQueueTimeoutError): + await client.test.command("ping") + self.assertFalse(listener.events_by_type(monitoring.PoolClearedEvent)) async def test_socketKeepAlive(self): pool = await async_get_pool(self.client) async with pool.checkout() as conn: - keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @no_type_check @@ -1517,7 +1588,7 @@ async def test_exhaust_network_error(self): # Cause a network error. conn = one(pool.conns) - conn.conn.close() + await conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) with self.assertRaises(ConnectionFailure): await anext(cursor) @@ -1542,7 +1613,7 @@ async def test_auth_network_error(self): # Cause a network error on the actual socket. pool = await async_get_pool(c) conn = one(pool.conns) - conn.conn.close() + await conn.conn.close() # AsyncConnection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. @@ -1790,6 +1861,29 @@ async def stall_connect(*args, **kwargs): # Each ping command should not take more than 2 seconds self.assertLess(total, 2) + async def test_background_connections_log_on_error(self): + with self.assertLogs("pymongo.client", level="ERROR") as cm: + client = await self.async_rs_or_single_client(minPoolSize=1) + # Create a single connection in the pool. + await client.admin.command("ping") + + # Cause new connections to fail. + pool = await async_get_pool(client) + + async def fail_connect(*args, **kwargs): + raise Exception("failed to connect") + + pool.connect = fail_connect + # Un-patch Pool.connect to break the cyclic reference. 
+ self.addCleanup(delattr, pool, "connect") + + await pool.reset_without_pause() + + await async_wait_until( + lambda: "failed to connect" in "".join(cm.output), "start creating connections" + ) + self.assertIn("MongoClient background task encountered an error", "".join(cm.output)) + @async_client_context.require_replica_set async def test_direct_connection(self): # direct_connection=True should result in Single topology. @@ -1824,20 +1918,20 @@ def server_description_count(): return i gc.collect() - with client_knobs(min_heartbeat_interval=0.003): + with client_knobs(min_heartbeat_interval=0.002): client = self.simple_client( - "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 + "invalid:27017", heartbeatFrequencyMS=2, serverSelectionTimeoutMS=200 ) initial_count = server_description_count() with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() final_count = server_description_count() + await client.close() # If a bug like PYTHON-2433 is reintroduced then too many # ServerDescriptions will be kept alive and this test will fail: - # AssertionError: 19 != 46 within 15 delta (27 difference) - # On Python 3.11 we seem to get more of a delta. - self.assertAlmostEqual(initial_count, final_count, delta=20) + # AssertionError: 11 != 47 within 20 delta (36 difference) + self.assertAlmostEqual(initial_count, final_count, delta=30) @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): @@ -1877,28 +1971,37 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() async def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + await client.aconnect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2202,7 +2305,7 @@ async def test_exhaust_query_network_error(self): # Cause a network error. conn = one(pool.conns) - conn.conn.close() + await conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) with self.assertRaises(ConnectionFailure): @@ -2230,7 +2333,7 @@ async def test_exhaust_getmore_network_error(self): # Cause a network error. conn = cursor._sock_mgr.conn - conn.conn.close() + await conn.conn.close() # A getmore fails. 
with self.assertRaises(ConnectionFailure): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 282009f554..9eb15298a6 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -25,7 +25,7 @@ async_client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch @@ -651,7 +651,6 @@ async def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 internal_client = await self.async_rs_or_single_client(timeoutMS=None) - self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] self.addAsyncCleanup(collection.drop) diff --git a/test/asynchronous/test_client_context.py b/test/asynchronous/test_client_context.py index 6a195eb6b8..afca1c0b26 100644 --- a/test/asynchronous/test_client_context.py +++ b/test/asynchronous/test_client_context.py @@ -47,20 +47,14 @@ def test_serverless(self): ) def test_enableTestCommands_is_disabled(self): - if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"): - raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set") + if not os.environ.get("DISABLE_TEST_COMMANDS"): + raise SkipTest("DISABLE_TEST_COMMANDS is not set") self.assertFalse( async_client_context.test_commands_enabled, - "enableTestCommands must be disabled when PYMONGO_DISABLE_TEST_COMMANDS is set.", + "enableTestCommands must be disabled when DISABLE_TEST_COMMANDS is set.", ) - def test_setdefaultencoding_worked(self): - if not os.environ.get("SETDEFAULTENCODING"): - raise SkipTest("SETDEFAULTENCODING is not set") - - self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) - def test_free_threading_is_enabled(self): if "free-threading build" not in sys.version: raise SkipTest("this test requires the Python free-threading build") diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index d7fd85b168..05e548c79e 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.asynchronous.helpers import anext diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index beb58012a8..00ed020d88 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.asynchronous.utils import async_get_pool, async_is_mongos from typing import Any, Iterable, no_type_check from pymongo.asynchronous.database import AsyncDatabase @@ -33,12 +34,10 @@ AsyncUnitTest, async_client_context, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - async_get_pool, - async_is_mongos, async_wait_until, ) diff --git a/test/asynchronous/test_comment.py b/test/asynchronous/test_comment.py index be3626a8b8..d3ddaf2b65 100644 --- a/test/asynchronous/test_comment.py +++ b/test/asynchronous/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import OvertCommandListener +from 
test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.asynchronous.command_cursor import AsyncCommandCursor diff --git a/test/asynchronous/test_concurrency.py b/test/asynchronous/test_concurrency.py index 1683b8413b..193ecf05c8 100644 --- a/test/asynchronous/test_concurrency.py +++ b/test/asynchronous/test_concurrency.py @@ -18,7 +18,7 @@ import asyncio import time from test.asynchronous import AsyncIntegrationTest, async_client_context -from test.utils import delay +from test.utils_shared import delay _IS_SYNC = False diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py new file mode 100644 index 0000000000..c6dc6f0a69 --- /dev/null +++ b/test/asynchronous/test_connection_monitoring.py @@ -0,0 +1,472 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Execute Connection Monitoring and Pooling (CMAP) Spec tests.""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +from pathlib import Path +from test.asynchronous.utils import async_get_pool, async_get_pools + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask +from test.utils_shared import ( + CMAPListener, + async_wait_until, + camel_to_snake, +) + +from bson.objectid import ObjectId +from bson.son import SON +from pymongo.asynchronous.pool import PoolState, _PoolClosedError +from pymongo.errors import ( + ConnectionFailure, + OperationFailure, + PyMongoError, + WaitQueueTimeoutError, +) +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionClosedReason, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_preferences import ReadPreference +from pymongo.topology_description import updated_topology_description + +_IS_SYNC = False + +OBJECT_TYPES = { + # Event types. + "ConnectionCheckedIn": ConnectionCheckedInEvent, + "ConnectionCheckedOut": ConnectionCheckedOutEvent, + "ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent, + "ConnectionClosed": ConnectionClosedEvent, + "ConnectionCreated": ConnectionCreatedEvent, + "ConnectionReady": ConnectionReadyEvent, + "ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent, + "ConnectionPoolCreated": PoolCreatedEvent, + "ConnectionPoolReady": PoolReadyEvent, + "ConnectionPoolCleared": PoolClearedEvent, + "ConnectionPoolClosed": PoolClosedEvent, + # Error types. + "PoolClosedError": _PoolClosedError, + "WaitQueueTimeoutError": WaitQueueTimeoutError, +} + + +class AsyncTestCMAP(AsyncIntegrationTest): + # Location of JSON test specifications.
+ if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") + + # Test operations: + + async def start(self, op): + """Run the 'start' thread operation.""" + target = op["target"] + thread = SpecRunnerTask(target) + await thread.start() + self.targets[target] = thread + + async def wait(self, op): + """Run the 'wait' operation.""" + await asyncio.sleep(op["ms"] / 1000.0) + + async def wait_for_thread(self, op): + """Run the 'waitForThread' operation.""" + target = op["target"] + thread = self.targets[target] + await thread.stop() + await thread.join() + if thread.exc: + raise thread.exc + self.assertFalse(thread.ops) + + async def wait_for_event(self, op): + """Run the 'waitForEvent' operation.""" + event = OBJECT_TYPES[op["event"]] + count = op["count"] + timeout = op.get("timeout", 10000) / 1000.0 + await async_wait_until( + lambda: self.listener.event_count(event) >= count, + f"find {count} {event} event(s)", + timeout=timeout, + ) + + async def check_out(self, op): + """Run the 'checkOut' operation.""" + label = op["label"] + async with self.pool.checkout() as conn: + # Call 'pin_cursor' so we can hold the socket. + conn.pin_cursor() + if label: + self.labels[label] = conn + else: + self.addAsyncCleanup(conn.close_conn, None) + + async def check_in(self, op): + """Run the 'checkIn' operation.""" + label = op["connection"] + conn = self.labels[label] + await self.pool.checkin(conn) + + async def ready(self, op): + """Run the 'ready' operation.""" + await self.pool.ready() + + async def clear(self, op): + """Run the 'clear' operation.""" + if "interruptInUseConnections" in op: + await self.pool.reset(interrupt_connections=op["interruptInUseConnections"]) + else: + await self.pool.reset() + + async def close(self, op): + """Run the 'close' operation.""" + await self.pool.close() + + async def run_operation(self, op): + """Run a single operation in a test.""" + op_name = camel_to_snake(op["name"]) + thread = op["thread"] + meth = getattr(self, op_name) + if thread: + await self.targets[thread].schedule(lambda: meth(op)) + else: + await meth(op) + + async def run_operations(self, ops): + """Run a test's operations.""" + for op in ops: + self._ops.append(op) + await self.run_operation(op) + + def check_object(self, actual, expected): + """Assert that the actual object matches the expected object.""" + self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]]) + for attr, expected_val in expected.items(): + if attr == "type": + continue + c2s = camel_to_snake(attr) + if c2s == "interrupt_in_use_connections": + c2s = "interrupt_connections" + actual_val = getattr(actual, c2s) + if expected_val == 42: + self.assertIsNotNone(actual_val) + else: + self.assertEqual(actual_val, expected_val) + + def check_event(self, actual, expected): + """Assert that the actual event matches the expected event.""" + self.check_object(actual, expected) + + def actual_events(self, ignore): + """Return all the non-ignored events.""" + ignore = tuple(OBJECT_TYPES[name] for name in ignore) + return [event for event in self.listener.events if not isinstance(event, ignore)] + + def check_events(self, events, ignore): + """Check the events of a test.""" + actual_events = self.actual_events(ignore) + for actual, expected in zip(actual_events, events): + self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}") + self.check_event(actual, expected) + + if 
len(events) > len(actual_events): + self.fail(f"missing events: {events[len(actual_events) :]!r}") + + def check_error(self, actual, expected): + message = expected.pop("message") + self.check_object(actual, expected) + self.assertIn(message, str(actual)) + + async def set_fail_point(self, command_args): + if not async_client_context.supports_failCommand_fail_point: + self.skipTest("failCommand fail point must be supported") + await self.configure_fail_point(self.client, command_args) + + async def run_scenario(self, scenario_def, test): + """Run a CMAP spec test.""" + self.logs: list = [] + self.assertEqual(scenario_def["version"], 1) + self.assertIn(scenario_def["style"], ["unit", "integration"]) + self.listener = CMAPListener() + self._ops: list = [] + + # Configure the fail point before creating the client. + if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + opts = test["poolOptions"].copy() + opts["event_listeners"] = [self.listener] + opts["_monitor_class"] = DummyMonitor + opts["connect"] = False + # Support backgroundThreadIntervalMS, default to 50ms. + interval = opts.pop("backgroundThreadIntervalMS", 50) + if interval < 0: + kill_cursor_frequency = 99999999 + else: + kill_cursor_frequency = interval / 1000.0 + with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): + client = await self.async_single_client(**opts) + # Update the SD to a known type because the DummyMonitor will not. + # Note we cannot simply call topology.on_change because that would + # internally call pool.ready() which introduces unexpected + # PoolReadyEvents. Instead, update the initial state before + # opening the Topology. + td = async_client_context.client._topology.description + sd = td.server_descriptions()[ + (await async_client_context.host, await async_client_context.port) + ] + client._topology._description = updated_topology_description( + client._topology._description, sd + ) + # When backgroundThreadIntervalMS is negative we do not start the + # background thread to ensure it never runs. + if interval < 0: + await client._topology.open() + else: + await client._get_topology() + self.pool = list(client._topology._servers.values())[0].pool + + # Map of target names to Thread objects. + self.targets: dict = {} + # Map of label names to AsyncConnection objects + self.labels: dict = {} + + async def cleanup(): + for t in self.targets.values(): + await t.stop() + for t in self.targets.values(): + await t.join(5) + for conn in self.labels.values(): + await conn.close_conn(None) + + self.addAsyncCleanup(cleanup) + + try: + if test["error"]: + with self.assertRaises(PyMongoError) as ctx: + await self.run_operations(test["operations"]) + self.check_error(ctx.exception, test["error"]) + else: + await self.run_operations(test["operations"]) + + self.check_events(test["events"], test["ignore"]) + except Exception: + # Print the events after a test failure. 
+ print("\nFailed test: {!r}".format(test["description"])) + print("Operations:") + for op in self._ops: + print(op) + print("Threads:") + print(self.targets) + print("AsyncConnections:") + print(self.labels) + print("Events:") + for event in self.listener.events: + print(event) + print("Log:") + for log in self.logs: + print(log) + raise + + POOL_OPTIONS = { + "maxPoolSize": 50, + "minPoolSize": 1, + "maxIdleTimeMS": 10000, + "waitQueueTimeoutMS": 10000, + } + + # + # Prose tests. Numbers correspond to the prose test number in the spec. + # + async def test_1_client_connection_pool_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_2_all_client_pools_have_same_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + await client.admin.command("ping") + # Discover at least one secondary. + if await async_client_context.has_secondaries: + await client.admin.command("ping", read_preference=ReadPreference.SECONDARY) + pools = await async_get_pools(client) + pool_opts = pools[0].opts + + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + for pool in pools[1:]: + self.assertEqual(pool.opts, pool_opts) + + async def test_3_uri_connection_pool_options(self): + opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) + uri = f"mongodb://{await async_client_context.pair}/?{opts}" + client = await self.async_rs_or_single_client(uri) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_4_subscribe_to_events(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + self.assertEqual(listener.event_count(PoolCreatedEvent), 1) + + # Creates a new connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1) + self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) + self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) + + # Uses the existing connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) + + await client.close() + self.assertEqual(listener.event_count(PoolClosedEvent), 1) + self.assertEqual(listener.event_count(ConnectionClosedEvent), 1) + + async def test_5_check_out_fails_connection_error(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + pool = await async_get_pool(client) + + def mock_connect(*args, **kwargs): + raise ConnectionFailure("connect failed") + + pool.connect = mock_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addCleanup(delattr, pool, "connect") + + # Attempt to create a new connection. 
+ with self.assertRaisesRegex(ConnectionFailure, "connect failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent) + self.assertIsInstance(listener.events[4], PoolClearedEvent) + + failed_event = listener.events[3] + self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + @async_client_context.require_no_fips + async def test_5_check_out_fails_auth_error(self): + listener = CMAPListener() + client = await self.async_single_client_noauth( + username="notauser", password="fail", event_listeners=[listener] + ) + + # Attempt to create a new connection. + with self.assertRaisesRegex(OperationFailure, "failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCreatedEvent) + # Error happens here. + self.assertIsInstance(listener.events[4], ConnectionClosedEvent) + self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent) + self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + # + # Extra non-spec tests + # + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + async def test_events_repr(self): + host = ("localhost", 27017) + self.assertRepr(ConnectionCheckedInEvent(host, 1)) + self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic())) + self.assertRepr( + ConnectionCheckOutFailedEvent( + host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic() + ) + ) + self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED)) + self.assertRepr(ConnectionCreatedEvent(host, 1)) + self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic())) + self.assertRepr(ConnectionCheckOutStartedEvent(host)) + self.assertRepr(PoolCreatedEvent(host, {})) + self.assertRepr(PoolClearedEvent(host)) + self.assertRepr(PoolClearedEvent(host, service_id=ObjectId())) + self.assertRepr(PoolClosedEvent(host)) + + async def test_close_leaves_pool_unpaused(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + await client.admin.command("ping") + pool = await async_get_pool(client) + await client.close() + self.assertEqual(1, listener.event_count(PoolClosedEvent)) + self.assertEqual(PoolState.CLOSED, pool.state) + # Checking out a connection should fail + with self.assertRaises(_PoolClosedError): + async with pool.checkout(): + pass + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def, test) + + return run_scenario + + +class CMAPSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + CMAP tests do not have a 'tests' field. The whole file represents + a single test case. 
+ """ + return [scenario_def] + + +test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) +test_creator.create_tests() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 7c11742a90..92c750c4fe 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.asynchronous.utils import async_ensure_all_connected sys.path[0:0] = [""] @@ -25,9 +26,8 @@ unittest, ) from test.asynchronous.helpers import async_repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, - async_ensure_all_connected, ) from bson import SON diff --git a/test/asynchronous/test_csot.py b/test/asynchronous/test_csot.py new file mode 100644 index 0000000000..9e928c2251 --- /dev/null +++ b/test/asynchronous/test_csot.py @@ -0,0 +1,118 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the CSOT unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes + +import pymongo +from pymongo import _csot +from pymongo.errors import PyMongoError + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestCSOT(AsyncIntegrationTest): + RUN_ON_SERVERLESS = True + RUN_ON_LOAD_BALANCER = True + + async def test_timeout_nested(self): + if os.environ.get("SKIP_CSOT_TESTS", ""): + raise unittest.SkipTest("SKIP_CSOT_TESTS is set, skipping...") + coll = self.db.coll + self.assertEqual(_csot.get_timeout(), None) + self.assertEqual(_csot.get_deadline(), float("inf")) + self.assertEqual(_csot.get_rtt(), 0.0) + with pymongo.timeout(10): + await coll.find_one() + self.assertEqual(_csot.get_timeout(), 10) + deadline_10 = _csot.get_deadline() + + # Capped at the original 10 deadline. 
+ with pymongo.timeout(15): + await coll.find_one() + self.assertEqual(_csot.get_timeout(), 15) + self.assertEqual(_csot.get_deadline(), deadline_10) + + # Should be reset to previous values + self.assertEqual(_csot.get_timeout(), 10) + self.assertEqual(_csot.get_deadline(), deadline_10) + await coll.find_one() + + with pymongo.timeout(5): + await coll.find_one() + self.assertEqual(_csot.get_timeout(), 5) + self.assertLess(_csot.get_deadline(), deadline_10) + + # Should be reset to previous values + self.assertEqual(_csot.get_timeout(), 10) + self.assertEqual(_csot.get_deadline(), deadline_10) + await coll.find_one() + + # Should be reset to previous values + self.assertEqual(_csot.get_timeout(), None) + self.assertEqual(_csot.get_deadline(), float("inf")) + self.assertEqual(_csot.get_rtt(), 0.0) + + @async_client_context.require_change_streams + async def test_change_stream_can_resume_after_timeouts(self): + if os.environ.get("SKIP_CSOT_TESTS", ""): + raise unittest.SkipTest("SKIP_CSOT_TESTS is set, skipping...") + coll = self.db.test + await coll.insert_one({}) + async with await coll.watch() as stream: + with pymongo.timeout(0.1): + with self.assertRaises(PyMongoError) as ctx: + await stream.next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + with self.assertRaises(PyMongoError) as ctx: + await stream.try_next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + # Resume before the insert on 3.6 because 4.0 is required to avoid skipping documents + if async_client_context.version < (4, 0): + await stream.try_next() + await coll.insert_one({}) + with pymongo.timeout(10): + self.assertTrue(await stream.next()) + self.assertTrue(stream.alive) + # Timeout applies to entire next() call, not only individual commands. 
+ with pymongo.timeout(0.5): + with self.assertRaises(PyMongoError) as ctx: + await stream.next() + self.assertTrue(ctx.exception.timeout) + self.assertTrue(stream.alive) + self.assertFalse(stream.alive) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d843ffb4aa..3c8570f336 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, @@ -1401,7 +1401,7 @@ async def test_to_list_empty(self): async def test_to_list_length(self): coll = self.db.test await coll.insert_many([{} for _ in range(5)]) - self.addCleanup(coll.drop) + self.addAsyncCleanup(coll.drop) c = coll.find() docs = await c.to_list(3) self.assertEqual(len(docs), 3) @@ -1812,6 +1812,7 @@ async def test_monitoring(self): @async_client_context.require_version_min(5, 0, -1) @async_client_context.require_no_mongos + @async_client_context.require_sync async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) @@ -1821,7 +1822,7 @@ async def test_exhaust_cursor_db_set(self): listener.reset() - result = await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list() + result = list(await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1)) self.assertEqual(len(result), 3) diff --git a/test/asynchronous/test_custom_types.py b/test/asynchronous/test_custom_types.py new file mode 100644 index 0000000000..0f9d737afe --- /dev/null +++ b/test/asynchronous/test_custom_types.py @@ -0,0 +1,989 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
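One rule worth making explicit from the test_timeout_nested assertions above: an inner pymongo.timeout block may tighten the operative deadline, but it can never extend it past the enclosing block's deadline. A compact standalone sketch, inspecting state with the same private _csot helpers the test itself imports (internal API, used here only for illustration):

import pymongo
from pymongo import _csot

with pymongo.timeout(10):
    outer_deadline = _csot.get_deadline()

    with pymongo.timeout(15):
        # The larger inner timeout is recorded, but the effective deadline
        # stays capped at the outer block's deadline.
        assert _csot.get_timeout() == 15
        assert _csot.get_deadline() == outer_deadline

    with pymongo.timeout(5):
        # A smaller inner timeout genuinely tightens the deadline.
        assert _csot.get_deadline() < outer_deadline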
+ +"""Test support for callbacks to encode/decode custom types.""" +from __future__ import annotations + +import datetime +import sys +import tempfile +from collections import OrderedDict +from decimal import Decimal +from random import random +from typing import Any, Tuple, Type, no_type_check + +from gridfs.asynchronous.grid_file import AsyncGridIn, AsyncGridOut + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest + +from bson import ( + _BUILT_IN_TYPES, + RE_TYPE, + Decimal128, + _bson_to_dict, + _dict_to_bson, + decode, + decode_all, + decode_file_iter, + decode_iter, + encode, +) +from bson.codec_options import ( + CodecOptions, + TypeCodec, + TypeDecoder, + TypeEncoder, + TypeRegistry, +) +from bson.errors import InvalidDocument +from bson.int64 import Int64 +from bson.raw_bson import RawBSONDocument +from pymongo.asynchronous.collection import ReturnDocument +from pymongo.asynchronous.helpers import anext +from pymongo.errors import DuplicateKeyError +from pymongo.message import _CursorAddress + +_IS_SYNC = False + + +class DecimalEncoder(TypeEncoder): + @property + def python_type(self): + return Decimal + + def transform_python(self, value): + return Decimal128(value) + + +class DecimalDecoder(TypeDecoder): + @property + def bson_type(self): + return Decimal128 + + def transform_bson(self, value): + return value.to_decimal() + + +class DecimalCodec(DecimalDecoder, DecimalEncoder): + pass + + +DECIMAL_CODECOPTS = CodecOptions(type_registry=TypeRegistry([DecimalCodec()])) + + +class UndecipherableInt64Type: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.value == other.value + # Does not compare equal to integers. 
+ return False + + +class UndecipherableIntDecoder(TypeDecoder): + bson_type = Int64 + + def transform_bson(self, value): + return UndecipherableInt64Type(value) + + +class UndecipherableIntEncoder(TypeEncoder): + python_type = UndecipherableInt64Type + + def transform_python(self, value): + return Int64(value.value) + + +UNINT_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry( + [ + UndecipherableIntDecoder(), + ] + ) +) + + +UNINT_CODECOPTS = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder(), UndecipherableIntEncoder()]) +) + + +class UppercaseTextDecoder(TypeDecoder): + bson_type = str + + def transform_bson(self, value): + return value.upper() + + +UPPERSTR_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry( + [ + UppercaseTextDecoder(), + ] + ) +) + + +def type_obfuscating_decoder_factory(rt_type): + class ResumeTokenToNanDecoder(TypeDecoder): + bson_type = rt_type + + def transform_bson(self, value): + return "NaN" + + return ResumeTokenToNanDecoder + + +class CustomBSONTypeTests: + @no_type_check + def roundtrip(self, doc): + bsonbytes = encode(doc, codec_options=self.codecopts) + rt_document = decode(bsonbytes, codec_options=self.codecopts) + self.assertEqual(doc, rt_document) + + def test_encode_decode_roundtrip(self): + self.roundtrip({"average": Decimal("56.47")}) + self.roundtrip({"average": {"b": Decimal("56.47")}}) + self.roundtrip({"average": [Decimal("56.47")]}) + self.roundtrip({"average": [[Decimal("56.47")]]}) + self.roundtrip({"average": [{"b": Decimal("56.47")}]}) + + @no_type_check + def test_decode_all(self): + documents = [] + for dec in range(3): + documents.append({"average": Decimal(f"56.4{dec}")}) + + bsonstream = b"" + for doc in documents: + bsonstream += encode(doc, codec_options=self.codecopts) + + self.assertEqual(decode_all(bsonstream, self.codecopts), documents) + + @no_type_check + def test__bson_to_dict(self): + document = {"average": Decimal("56.47")} + rawbytes = encode(document, codec_options=self.codecopts) + decoded_document = _bson_to_dict(rawbytes, self.codecopts) + self.assertEqual(document, decoded_document) + + @no_type_check + def test__dict_to_bson(self): + document = {"average": Decimal("56.47")} + rawbytes = encode(document, codec_options=self.codecopts) + encoded_document = _dict_to_bson(document, False, self.codecopts) + self.assertEqual(encoded_document, rawbytes) + + def _generate_multidocument_bson_stream(self): + inp_num = [str(random() * 100)[:4] for _ in range(10)] + docs = [{"n": Decimal128(dec)} for dec in inp_num] + edocs = [{"n": Decimal(dec)} for dec in inp_num] + bsonstream = b"" + for doc in docs: + bsonstream += encode(doc) + return edocs, bsonstream + + @no_type_check + def test_decode_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + for expected_doc, decoded_doc in zip(expected, decode_iter(bson_data, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + @no_type_check + def test_decode_file_iter(self): + expected, bson_data = self._generate_multidocument_bson_stream() + fileobj = tempfile.TemporaryFile() + fileobj.write(bson_data) + fileobj.seek(0) + + for expected_doc, decoded_doc in zip(expected, decode_file_iter(fileobj, self.codecopts)): + self.assertEqual(expected_doc, decoded_doc) + + fileobj.close() + + +class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.codecopts = DECIMAL_CODECOPTS + + +class 
TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + codec_options = CodecOptions( + type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())) + ) + cls.codecopts = codec_options + + +class TestBSONFallbackEncoder(unittest.TestCase): + def _get_codec_options(self, fallback_encoder): + type_registry = TypeRegistry(fallback_encoder=fallback_encoder) + return CodecOptions(type_registry=type_registry) + + def test_simple(self): + codecopts = self._get_codec_options(lambda x: Decimal128(x)) + document = {"average": Decimal("56.47")} + bsonbytes = encode(document, codec_options=codecopts) + + exp_document = {"average": Decimal128("56.47")} + exp_bsonbytes = encode(exp_document) + self.assertEqual(bsonbytes, exp_bsonbytes) + + def test_erroring_fallback_encoder(self): + codecopts = self._get_codec_options(lambda _: 1 / 0) + + # fallback converter should not be invoked when encoding known types. + encode( + {"a": 1, "b": Decimal128("1.01"), "c": {"arr": ["abc", 3.678]}}, codec_options=codecopts + ) + + # expect an error when encoding a custom type. + document = {"average": Decimal("56.47")} + with self.assertRaises(ZeroDivisionError): + encode(document, codec_options=codecopts) + + def test_noop_fallback_encoder(self): + codecopts = self._get_codec_options(lambda x: x) + document = {"average": Decimal("56.47")} + with self.assertRaises(InvalidDocument): + encode(document, codec_options=codecopts) + + def test_type_unencodable_by_fallback_encoder(self): + def fallback_encoder(value): + try: + return Decimal128(value) + except: + raise TypeError("cannot encode type %s" % (type(value))) + + codecopts = self._get_codec_options(fallback_encoder) + document = {"average": Decimal} + with self.assertRaises(TypeError): + encode(document, codec_options=codecopts) + + def test_call_only_once_for_not_handled_big_integers(self): + called_with = [] + + def fallback_encoder(value): + called_with.append(value) + return value + + codecopts = self._get_codec_options(fallback_encoder) + document = {"a": {"b": {"c": 2 << 65}}} + + msg = "MongoDB can only handle up to 8-byte ints" + with self.assertRaises(OverflowError, msg=msg): + encode(document, codec_options=codecopts) + + self.assertEqual(called_with, [2 << 65]) + + +class TestBSONTypeEnDeCodecs(unittest.TestCase): + def test_instantiation(self): + msg = "Can't instantiate abstract class" + + def run_test(base, attrs, fail): + codec = type("testcodec", (base,), attrs) + if fail: + with self.assertRaisesRegex(TypeError, msg): + codec() + else: + codec() + + class MyType: + pass + + run_test( + TypeEncoder, + { + "python_type": MyType, + }, + fail=True, + ) + run_test(TypeEncoder, {"transform_python": lambda s, x: x}, fail=True) + run_test( + TypeEncoder, {"transform_python": lambda s, x: x, "python_type": MyType}, fail=False + ) + + run_test( + TypeDecoder, + { + "bson_type": Decimal128, + }, + fail=True, + ) + run_test(TypeDecoder, {"transform_bson": lambda s, x: x}, fail=True) + run_test( + TypeDecoder, {"transform_bson": lambda s, x: x, "bson_type": Decimal128}, fail=False + ) + + run_test(TypeCodec, {"bson_type": Decimal128, "python_type": MyType}, fail=True) + run_test( + TypeCodec, + {"transform_bson": lambda s, x: x, "transform_python": lambda s, x: x}, + fail=True, + ) + run_test( + TypeCodec, + { + "python_type": MyType, + "transform_python": lambda s, x: x, + "transform_bson": lambda s, x: x, + "bson_type": Decimal128, + }, + fail=False, + ) + + def test_type_checks(self): 
+ self.assertTrue(issubclass(TypeCodec, TypeEncoder)) + self.assertTrue(issubclass(TypeCodec, TypeDecoder)) + self.assertFalse(issubclass(TypeDecoder, TypeEncoder)) + self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) + + +class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): + TypeA: Any + TypeB: Any + fallback_encoder_A2B: Any + fallback_encoder_A2BSON: Any + B2BSON: Type[TypeEncoder] + B2A: Type[TypeEncoder] + A2B: Type[TypeEncoder] + + @classmethod + def setUpClass(cls): + class TypeA: + def __init__(self, x): + self.value = x + + class TypeB: + def __init__(self, x): + self.value = x + + # transforms A, and only A into B + def fallback_encoder_A2B(value): + assert isinstance(value, TypeA) + return TypeB(value.value) + + # transforms A, and only A into something encodable + def fallback_encoder_A2BSON(value): + assert isinstance(value, TypeA) + return value.value + + # transforms B into something encodable + class B2BSON(TypeEncoder): + python_type = TypeB + + def transform_python(self, value): + return value.value + + # transforms A into B + # technically, this isn't a proper type encoder as the output is not + # BSON-encodable. + class A2B(TypeEncoder): + python_type = TypeA + + def transform_python(self, value): + return TypeB(value.value) + + # transforms B into A + # technically, this isn't a proper type encoder as the output is not + # BSON-encodable. + class B2A(TypeEncoder): + python_type = TypeB + + def transform_python(self, value): + return TypeA(value.value) + + cls.TypeA = TypeA + cls.TypeB = TypeB + cls.fallback_encoder_A2B = staticmethod(fallback_encoder_A2B) + cls.fallback_encoder_A2BSON = staticmethod(fallback_encoder_A2BSON) + cls.B2BSON = B2BSON + cls.B2A = B2A + cls.A2B = A2B + + def test_encode_fallback_then_custom(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B) + ) + testdoc = {"x": self.TypeA(123)} + expected_bytes = encode({"x": 123}) + + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) + + def test_encode_custom_then_fallback(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON) + ) + testdoc = {"x": self.TypeB(123)} + expected_bytes = encode({"x": 123}) + + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) + + def test_chaining_encoders_fails(self): + codecopts = CodecOptions(type_registry=TypeRegistry([self.A2B(), self.B2BSON()])) + + with self.assertRaises(InvalidDocument): + encode({"x": self.TypeA(123)}, codec_options=codecopts) + + def test_infinite_loop_exceeds_max_recursion_depth(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2B) + ) + + # Raises max recursion depth exceeded error + with self.assertRaises(RuntimeError): + encode({"x": self.TypeA(100)}, codec_options=codecopts) + + +class TestTypeRegistry(unittest.TestCase): + types: Tuple[object, object] + codecs: Tuple[Type[TypeCodec], Type[TypeCodec]] + fallback_encoder: Any + + @classmethod + def setUpClass(cls): + class MyIntType: + def __init__(self, x): + assert isinstance(x, int) + self.x = x + + class MyStrType: + def __init__(self, x): + assert isinstance(x, str) + self.x = x + + class MyIntCodec(TypeCodec): + @property + def python_type(self): + return MyIntType + + @property + def bson_type(self): + return int + + def transform_python(self, value): + return value.x + + def transform_bson(self, value): 
+ return MyIntType(value) + + class MyStrCodec(TypeCodec): + @property + def python_type(self): + return MyStrType + + @property + def bson_type(self): + return str + + def transform_python(self, value): + return value.x + + def transform_bson(self, value): + return MyStrType(value) + + def fallback_encoder(value): + return value + + cls.types = (MyIntType, MyStrType) + cls.codecs = (MyIntCodec, MyStrCodec) + cls.fallback_encoder = fallback_encoder + + def test_simple(self): + codec_instances = [codec() for codec in self.codecs] + + def assert_proper_initialization(type_registry, codec_instances): + self.assertEqual( + type_registry._encoder_map, + { + self.types[0]: codec_instances[0].transform_python, + self.types[1]: codec_instances[1].transform_python, + }, + ) + self.assertEqual( + type_registry._decoder_map, + {int: codec_instances[0].transform_bson, str: codec_instances[1].transform_bson}, + ) + self.assertEqual(type_registry._fallback_encoder, self.fallback_encoder) + + type_registry = TypeRegistry(codec_instances, self.fallback_encoder) + assert_proper_initialization(type_registry, codec_instances) + + type_registry = TypeRegistry( + fallback_encoder=self.fallback_encoder, type_codecs=codec_instances + ) + assert_proper_initialization(type_registry, codec_instances) + + # Ensure codec list held by the type registry doesn't change if we + # mutate the initial list. + codec_instances_copy = list(codec_instances) + codec_instances.pop(0) + self.assertListEqual(type_registry._TypeRegistry__type_codecs, codec_instances_copy) + + def test_simple_separate_codecs(self): + class MyIntEncoder(TypeEncoder): + python_type = self.types[0] + + def transform_python(self, value): + return value.x + + class MyIntDecoder(TypeDecoder): + bson_type = int + + def transform_bson(self, value): + return self.types[0](value) + + codec_instances: list = [MyIntDecoder(), MyIntEncoder()] + type_registry = TypeRegistry(codec_instances) + + self.assertEqual( + type_registry._encoder_map, + {MyIntEncoder.python_type: codec_instances[1].transform_python}, + ) + self.assertEqual( + type_registry._decoder_map, + {MyIntDecoder.bson_type: codec_instances[0].transform_bson}, + ) + + def test_initialize_fail(self): + err_msg = "Expected an instance of TypeEncoder, TypeDecoder, or TypeCodec, got .* instead" + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(self.codecs) # type: ignore[arg-type] + + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry([type("AnyType", (object,), {})()]) + + err_msg = f"fallback_encoder {True!r} is not a callable" + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry([], True) # type: ignore[arg-type] + + err_msg = "fallback_encoder {!r} is not a callable".format("hello") + with self.assertRaisesRegex(TypeError, err_msg): + TypeRegistry(fallback_encoder="hello") # type: ignore[arg-type] + + def test_type_registry_repr(self): + codec_instances = [codec() for codec in self.codecs] + type_registry = TypeRegistry(codec_instances) + r = f"TypeRegistry(type_codecs={codec_instances!r}, fallback_encoder={None!r})" + self.assertEqual(r, repr(type_registry)) + + def test_type_registry_eq(self): + codec_instances = [codec() for codec in self.codecs] + self.assertEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances)) + + codec_instances_2 = [codec() for codec in self.codecs] + self.assertNotEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) + + def test_builtin_types_override_fails(self): + def run_test(base, attrs): + msg = ( + 
r"TypeEncoders cannot change how built-in types " + r"are encoded \(encoder .* transforms type .*\)" + ) + for pytype in _BUILT_IN_TYPES: + attrs.update({"python_type": pytype, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) + codec_instance = codec() + with self.assertRaisesRegex(TypeError, msg): + TypeRegistry( + [ + codec_instance, + ] + ) + + # Test only some subtypes as not all can be subclassed. + if pytype in [ + bool, + type(None), + RE_TYPE, + ]: + continue + + class MyType(pytype): # type: ignore + pass + + attrs.update({"python_type": MyType, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) + codec_instance = codec() + with self.assertRaisesRegex(TypeError, msg): + TypeRegistry( + [ + codec_instance, + ] + ) + + run_test(TypeEncoder, {}) + run_test(TypeCodec, {"bson_type": Decimal128, "transform_bson": lambda x: x}) + + +class TestCollectionWCustomType(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.drop() + + async def asyncTearDown(self): + await self.db.test.drop() + + async def test_overflow_int_w_custom_decoder(self): + type_registry = TypeRegistry(fallback_encoder=lambda val: str(val)) + codec_options = CodecOptions(type_registry=type_registry) + collection = self.db.get_collection("test", codec_options=codec_options) + + await collection.insert_one({"_id": 1, "data": 2**520}) + ret = await collection.find_one() + self.assertEqual(ret["data"], str(2**520)) + + async def test_command_errors_w_custom_type_decoder(self): + db = self.db + test_doc = {"_id": 1, "data": "a"} + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + + result = await test.insert_one(test_doc) + self.assertEqual(result.inserted_id, test_doc["_id"]) + with self.assertRaises(DuplicateKeyError): + await test.insert_one(test_doc) + + async def test_find_w_custom_type_decoder(self): + db = self.db + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] + for doc in input_docs: + await db.test.insert_one(doc) + + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + async for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + + async def test_find_w_custom_type_decoder_and_document_class(self): + async def run_test(doc_cls): + db = self.db + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] + for doc in input_docs: + await db.test.insert_one(doc) + + test = db.get_collection( + "test", + codec_options=CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder()]), document_class=doc_cls + ), + ) + async for doc in test.find({}, batch_size=1): + self.assertIsInstance(doc, doc_cls) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + + for doc_cls in [RawBSONDocument, OrderedDict]: + await run_test(doc_cls) + + async def test_aggregate_w_custom_type_decoder(self): + db = self.db + await db.test.insert_many( + [ + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + ] + ) + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + + pipeline: list = [ + {"$match": {"status": "complete"}}, + {"$group": {"_id": "$status", "total_qty": {"$sum": "$qty"}}}, + ] + result = await test.aggregate(pipeline) + + res = (await result.to_list())[0] + self.assertEqual(res["_id"], "complete") + 
self.assertIsInstance(res["total_qty"], UndecipherableInt64Type) + self.assertEqual(res["total_qty"].value, 20) + + async def test_distinct_w_custom_type(self): + await self.db.drop_collection("test") + + test = self.db.get_collection("test", codec_options=UNINT_CODECOPTS) + values = [ + UndecipherableInt64Type(1), + UndecipherableInt64Type(2), + UndecipherableInt64Type(3), + {"b": UndecipherableInt64Type(3)}, + ] + await test.insert_many({"a": val} for val in values) + + self.assertEqual(values, await test.distinct("a")) + + async def test_find_one_and__w_custom_type_decoder(self): + db = self.db + c = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + await c.insert_one({"_id": 1, "x": Int64(1)}) + + doc = await c.find_one_and_update( + {"_id": 1}, {"$inc": {"x": 1}}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 2) + + doc = await c.find_one_and_replace( + {"_id": 1}, {"x": Int64(3), "y": True}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) + self.assertEqual(doc["y"], True) + + doc = await c.find_one_and_delete({"y": True}) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) + self.assertIsNone(await c.find_one()) + + +class TestGridFileCustomType(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + + async def test_grid_out_custom_opts(self): + db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS) + one = AsyncGridIn( + db.fs, + _id=5, + filename="my_file", + chunkSize=1000, + metadata={"foo": "red", "bar": "blue"}, + bar=3, + baz="hello", + ) + await one.write(b"hello world") + await one.close() + + two = AsyncGridOut(db.fs, 5) + await two.open() + + self.assertEqual("my_file", two.name) + self.assertEqual("my_file", two.filename) + self.assertEqual(5, two._id) + self.assertEqual(11, two.length) + self.assertEqual(1000, two.chunk_size) + self.assertTrue(isinstance(two.upload_date, datetime.datetime)) + self.assertEqual({"foo": "red", "bar": "blue"}, two.metadata) + self.assertEqual(3, two.bar) + + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: + self.assertRaises(AttributeError, setattr, two, attr, 5) + + +class ChangeStreamsWCustomTypesTestMixin: + @no_type_check + async def change_stream(self, *args, **kwargs): + stream = await self.watched_target.watch(*args, max_await_time_ms=1, **kwargs) + self.addAsyncCleanup(stream.close) + return stream + + @no_type_check + async def insert_and_check(self, change_stream, insert_doc, expected_doc): + await self.input_target.insert_one(insert_doc) + change = await anext(change_stream) + self.assertEqual(change["fullDocument"], expected_doc) + + @no_type_check + async def kill_change_stream_cursor(self, change_stream): + # Cause a cursor not found error on the next getMore. 
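+        # Killing the cursor out-of-band (with a killCursors command) leaves
+        # the change stream unaware that its server-side cursor is gone, so
+        # the next getMore fails and the stream has to resume. The tests
+        # below use this to check that custom type codecs survive the
+        # resume path.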
+ cursor = change_stream._cursor + address = _CursorAddress(cursor.address, cursor._ns) + client = self.input_target.database.client + await client._close_cursor_now(cursor.cursor_id, address) + + @no_type_check + async def test_simple(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) + await self.create_targets(codec_options=codecopts) + + input_docs = [ + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] + expected_docs = [ + {"_id": 1, "data": "HELLO"}, + {"_id": 2, "data": "WORLD"}, + {"_id": 3, "data": "!"}, + ] + + change_stream = await self.change_stream() + + await self.insert_and_check(change_stream, input_docs[0], expected_docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[1], expected_docs[1]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) + + @no_type_check + async def test_custom_type_in_pipeline(self): + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) + await self.create_targets(codec_options=codecopts) + + input_docs = [ + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] + expected_docs = [{"_id": 2, "data": "WORLD"}, {"_id": 3, "data": "!"}] + + # UndecipherableInt64Type should be encoded with the TypeRegistry. + change_stream = await self.change_stream( + [{"$match": {"documentKey._id": {"$gte": UndecipherableInt64Type(2)}}}] + ) + + await self.input_target.insert_one(input_docs[0]) + await self.insert_and_check(change_stream, input_docs[1], expected_docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, input_docs[2], expected_docs[1]) + + @no_type_check + async def test_break_resume_token(self): + # Get one document from a change stream to determine resumeToken type. + await self.create_targets() + change_stream = await self.change_stream() + await self.input_target.insert_one({"data": "test"}) + change = await anext(change_stream) + resume_token_decoder = type_obfuscating_decoder_factory(type(change["_id"]["_data"])) + + # Custom-decoding the resumeToken type breaks resume tokens. + codecopts = CodecOptions( + type_registry=TypeRegistry([resume_token_decoder(), UndecipherableIntEncoder()]) + ) + + # Re-create targets, change stream and proceed. 
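+        # Resuming should still work: the driver is expected to decode the
+        # resume token with its own raw codec options rather than through
+        # the user's type registry, so the obfuscating decoder never sees it.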
+ await self.create_targets(codec_options=codecopts) + + docs = [{"_id": 1}, {"_id": 2}, {"_id": 3}] + + change_stream = await self.change_stream() + await self.insert_and_check(change_stream, docs[0], docs[0]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, docs[1], docs[1]) + await self.kill_change_stream_cursor(change_stream) + await self.insert_and_check(change_stream, docs[2], docs[2]) + + @no_type_check + async def test_document_class(self): + async def run_test(doc_cls): + codecopts = CodecOptions( + type_registry=TypeRegistry([UppercaseTextDecoder(), UndecipherableIntEncoder()]), + document_class=doc_cls, + ) + + await self.create_targets(codec_options=codecopts) + change_stream = await self.change_stream() + + doc = {"a": UndecipherableInt64Type(101), "b": "xyz"} + await self.input_target.insert_one(doc) + change = await anext(change_stream) + + self.assertIsInstance(change, doc_cls) + self.assertEqual(change["fullDocument"]["a"], 101) + self.assertEqual(change["fullDocument"]["b"], "XYZ") + + for doc_cls in [OrderedDict, RawBSONDocument]: + await run_test(doc_cls) + + +class TestCollectionChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + + async def create_targets(self, *args, **kwargs): + self.watched_target = self.db.get_collection("test", *args, **kwargs) + self.input_target = self.watched_target + # Ensure the collection exists and is empty. + await self.input_target.insert_one({}) + await self.input_target.delete_many({}) + + +class TestDatabaseChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + await self.client.drop_database(self.watched_target) + + async def create_targets(self, *args, **kwargs): + self.watched_target = self.client.get_database(self.db.name, *args, **kwargs) + self.input_target = self.watched_target.test + # Insert a record to ensure db, coll are created. + await self.input_target.insert_one({"data": "dummy"}) + + +class TestClusterChangeStreamsWCustomTypes( + AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin +): + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_change_streams + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.delete_many({}) + + async def asyncTearDown(self): + await self.input_target.drop() + await self.client.drop_database(self.db) + + async def create_targets(self, *args, **kwargs): + codec_options = kwargs.pop("codec_options", None) + if codec_options: + kwargs["type_registry"] = codec_options.type_registry + kwargs["document_class"] = codec_options.document_class + self.watched_target = await self.async_rs_client(*args, **kwargs) + self.input_target = self.watched_target[self.db.name].test + # Insert a record to ensure db, coll are created. 
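+        # (The watched target is the whole database here, so the namespace
+        # must exist before the change stream helpers run against it.)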
+ await self.input_target.insert_one({"data": "dummy"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_data_lake.py b/test/asynchronous/test_data_lake.py new file mode 100644 index 0000000000..689bf38534 --- /dev/null +++ b/test/asynchronous/test_data_lake.py @@ -0,0 +1,107 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Atlas Data Lake.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils_shared import ( + OvertCommandListener, +) + +from pymongo.asynchronous.helpers import anext + +_IS_SYNC = False + +pytestmark = pytest.mark.data_lake + + +class TestDataLakeMustConnect(AsyncUnitTest): + async def test_connected_to_data_lake(self): + self.assertTrue( + async_client_context.is_data_lake and async_client_context.connected, + "client context must be connected to data lake when DATA_LAKE is set. Failed attempts:\n{}".format( + async_client_context.connection_attempt_info() + ), + ) + + +class TestDataLakeProse(AsyncIntegrationTest): + # Default test database and collection names. + TEST_DB = "test" + TEST_COLLECTION = "driverdata" + + @async_client_context.require_data_lake + async def asyncSetUp(self): + await super().asyncSetUp() + + # Test killCursors + async def test_1(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2) + await anext(cursor) + + # find command assertions + find_cmd = listener.succeeded_events[-1] + self.assertEqual(find_cmd.command_name, "find") + cursor_id = find_cmd.reply["cursor"]["id"] + cursor_ns = find_cmd.reply["cursor"]["ns"] + + # killCursors command assertions + await cursor.close() + started = listener.started_events[-1] + self.assertEqual(started.command_name, "killCursors") + succeeded = listener.succeeded_events[-1] + self.assertEqual(succeeded.command_name, "killCursors") + + self.assertIn(cursor_id, started.command["cursors"]) + target_ns = ".".join([started.command["$db"], started.command["killCursors"]]) + self.assertEqual(cursor_ns, target_ns) + + self.assertIn(cursor_id, succeeded.reply["cursorsKilled"]) + + # Test no auth + async def test_2(self): + client = await self.async_rs_client_noauth() + await client.admin.command("ping") + + # Test with auth + async def test_3(self): + for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]: + client = await self.async_rs_or_single_client(authMechanism=mechanism) + await client[self.TEST_DB][self.TEST_COLLECTION].find_one() + + +# Location of JSON test specifications. 
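+# The asynchronous tests live in test/asynchronous/, one directory below the
+# synchronous suite, so the shared JSON specs are resolved relative to the
+# parent directory when _IS_SYNC is False.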
+if _IS_SYNC: + TEST_PATH = Path(__file__).parent / "data_lake/unified" +else: + TEST_PATH = Path(__file__).parent.parent / "data_lake/unified" + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 55a8cc3ab2..b2ddd4122d 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -26,7 +26,7 @@ from test import unittest from test.asynchronous import AsyncIntegrationTest, async_client_context from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, async_wait_until, @@ -430,6 +430,21 @@ async def test_command_with_regex(self): for doc in result["cursor"]["firstBatch"]: self.assertTrue(isinstance(doc["r"], Regex)) + async def test_command_bulkWrite(self): + # Ensure bulk write commands can be run directly via db.command(). + if async_client_context.version.at_least(8, 0): + await self.client.admin.command( + { + "bulkWrite": 1, + "nsInfo": [{"ns": self.db.test.full_name}], + "ops": [{"insert": 0, "document": {}}], + } + ) + await self.db.command({"insert": "test", "documents": [{}]}) + await self.db.command({"update": "test", "updates": [{"q": {}, "u": {"$set": {"x": 1}}}]}) + await self.db.command({"delete": "test", "deletes": [{"q": {}, "limit": 1}]}) + await self.db.test.drop() + async def test_cursor_command(self): db = self.client.pymongo_test await db.test.drop() diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py new file mode 100644 index 0000000000..cf26faf248 --- /dev/null +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -0,0 +1,586 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the topology module.""" +from __future__ import annotations + +import asyncio +import os +import socketserver +import sys +import threading +import time +from asyncio import StreamReader, StreamWriter +from pathlib import Path +from test.asynchronous.helpers import ConcurrentRunner + +from pymongo.asynchronous.pool import AsyncConnection +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + AsyncUnitTest, + async_client_context, + unittest, +) +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.utils import ( + async_get_pool, +) +from test.utils_shared import ( + CMAPListener, + HeartbeatEventListener, + HeartbeatEventsListListener, + assertion_context, + async_barrier_wait, + async_create_barrier, + async_wait_until, + server_name_to_type, +) +from unittest.mock import patch + +from bson import Timestamp, json_util +from pymongo import common, monitoring +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NetworkTimeout, + NotPrimaryError, + OperationFailure, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _check_command_response, _check_write_command_response +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.server_description import SERVER_TYPE, ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) + + +async def create_mock_topology(uri, monitor_class=DummyMonitor): + parsed_uri = await parse_uri(uri) + replica_set_name = None + direct_connection = None + load_balanced = None + if "replicaset" in parsed_uri["options"]: + replica_set_name = parsed_uri["options"]["replicaset"] + if "directConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directConnection"] + if "loadBalanced" in parsed_uri["options"]: + load_balanced = parsed_uri["options"]["loadBalanced"] + + topology_settings = TopologySettings( + parsed_uri["nodelist"], + replica_set_name=replica_set_name, + monitor_class=monitor_class, + direct_connection=direct_connection, + load_balanced=load_balanced, + ) + + c = Topology(topology_settings) + await c.open() + return c + + +async def got_hello(topology, server_address, hello_response): + server_description = ServerDescription(server_address, Hello(hello_response), 0) + await topology.on_change(server_description) + + +async def got_app_error(topology, app_error): + server_address = common.partition_node(app_error["address"]) + server = topology.get_server_by_address(server_address) + error_type = app_error["type"] + generation = app_error.get("generation", server.pool.gen.get_overall()) + when = app_error["when"] + max_wire_version = app_error["maxWireVersion"] + # XXX: We could get better test coverage by mocking the errors on the + # Pool/AsyncConnection. 
+ try: + if error_type == "command": + _check_command_response(app_error["response"], max_wire_version) + _check_write_command_response(app_error["response"]) + elif error_type == "network": + raise AutoReconnect("mock non-timeout network error") + elif error_type == "timeout": + raise NetworkTimeout("mock network timeout error") + else: + raise AssertionError(f"unknown error type: {error_type}") + raise AssertionError + except (AutoReconnect, NotPrimaryError, OperationFailure) as e: + if when == "beforeHandshakeCompletes": + completed_handshake = False + elif when == "afterHandshakeCompletes": + completed_handshake = True + else: + raise AssertionError(f"Unknown when field {when}") + + await topology.handle_error( + server_address, + _ErrorContext(e, max_wire_version, generation, completed_handshake, None), + ) + + +def get_type(topology, hostname): + description = topology.get_server_by_address((hostname, 27017)).description + return description.server_type + + +class TestAllScenarios(AsyncUnitTest): + pass + + +def topology_type_name(topology_type): + return TOPOLOGY_TYPE._fields[topology_type] + + +def server_type_name(server_type): + return SERVER_TYPE._fields[server_type] + + +def check_outcome(self, topology, outcome): + expected_servers = outcome["servers"] + + # Check weak equality before proceeding. + self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers)) + + if outcome.get("compatible") is False: + with self.assertRaises(ConfigurationError): + topology.description.check_compatible() + else: + # No error. + topology.description.check_compatible() + + # Since lengths are equal, every actual server must have a corresponding + # expected server. + for expected_server_address, expected_server in expected_servers.items(): + node = common.partition_node(expected_server_address) + self.assertTrue(topology.has_server(node)) + actual_server = topology.get_server_by_address(node) + actual_server_description = actual_server.description + expected_server_type = server_name_to_type(expected_server["type"]) + + self.assertEqual( + server_type_name(expected_server_type), + server_type_name(actual_server_description.server_type), + ) + expected_error = expected_server.get("error") + if expected_error: + self.assertIn(expected_error, str(actual_server_description.error)) + + self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) + + self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version) + + self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id) + + self.assertEqual( + expected_server.get("topologyVersion"), actual_server_description.topology_version + ) + + expected_pool = expected_server.get("pool") + if expected_pool: + self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall()) + + self.assertEqual(outcome["setName"], topology.description.replica_set_name) + self.assertEqual( + outcome.get("logicalSessionTimeoutMinutes"), + topology.description.logical_session_timeout_minutes, + ) + + expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"]) + self.assertEqual( + topology_type_name(expected_topology_type), + topology_type_name(topology.description.topology_type), + ) + + self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version) + self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id) + + +def create_test(scenario_def): + async def run_scenario(self): + c 
= await create_mock_topology(scenario_def["uri"]) + + for i, phase in enumerate(scenario_def["phases"]): + # Including the phase description makes failures easier to debug. + description = phase.get("description", str(i)) + with assertion_context(f"phase: {description}"): + for response in phase.get("responses", []): + await got_hello(c, common.partition_node(response[0]), response[1]) + + for app_error in phase.get("applicationErrors", []): + await got_app_error(c, app_error) + + check_outcome(self, c, phase["outcome"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(SDAM_PATH): + dirname = os.path.split(dirpath)[-1] + # SDAM unified tests are handled separately. + if dirname == "unified": + continue + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + + +class TestClusterTimeComparison(AsyncPyMongoTestCase): + async def test_cluster_time_comparison(self): + t = await create_mock_topology("mongodb://host") + + async def send_cluster_time(time, inc): + old = t.max_cluster_time() + new = {"clusterTime": Timestamp(time, inc)} + await got_hello( + t, + ("host", 27017), + { + "ok": 1, + "minWireVersion": 0, + "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION, + "$clusterTime": new, + }, + ) + + actual = t.max_cluster_time() + # We never update $clusterTime from monitoring connections. + self.assertEqual(actual, old) + + await send_cluster_time(0, 1) + await send_cluster_time(2, 2) + await send_cluster_time(2, 1) + await send_cluster_time(1, 3) + await send_cluster_time(2, 3) + + +class TestIgnoreStaleErrors(AsyncIntegrationTest): + async def test_ignore_stale_connection_errors(self): + if not _IS_SYNC and sys.version_info < (3, 11): + self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") + N_TASKS = 5 + barrier = async_create_barrier(N_TASKS) + client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) + + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") + + async def mock_command(*args, **kwargs): + # Synchronize all tasks to ensure they use the same generation. + await async_barrier_wait(barrier, timeout=30) + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + tasks = [] + for i in range(N_TASKS): + tasks.append(ConcurrentRunner(target=insert_command, args=(i,))) + for t in tasks: + await t.start() + for t in tasks: + await t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. 
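+        # The mocked AutoReconnect marked the server Unknown, so a successful
+        # ping here also shows that monitoring rediscovered the server after
+        # the single pool reset asserted above.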
+ await client.admin.command("ping") + + +class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): + pass + + +class TestPoolManagement(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_pool_unpause(self): + # This test implements the prose test "AsyncConnection Pool Management" + listener = CMAPHeartbeatListener() + _ = await self.async_single_client( + appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] + ) + # Assert that AsyncConnectionPoolReadyEvent occurs after the first + # ServerHeartbeatSucceededEvent. + await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0] + hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0] + self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded)) + + listener.reset() + fail_hello = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMPoolManagementTest", + }, + } + async with self.fail_point(fail_hello): + await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + await listener.async_wait_for_event(monitoring.PoolClearedEvent, 1) + await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) + await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + + @async_client_context.require_failCommand_appName + @async_client_context.require_test_commands + @async_client_context.require_async + async def test_connection_close_does_not_block_other_operations(self): + listener = CMAPHeartbeatListener() + client = await self.async_single_client( + appName="SDAMConnectionCloseTest", + event_listeners=[listener], + heartbeatFrequencyMS=500, + minPoolSize=10, + ) + server = await (await client._get_topology()).select_server( + writable_server_selector, _Op.TEST + ) + await async_wait_until( + lambda: len(server._pool.conns) == 10, + "pool initialized with 10 connections", + ) + + await client.db.test.insert_one({"x": 1}) + close_delay = 0.1 + latencies = [] + should_exit = [] + + async def run_task(): + while True: + start_time = time.monotonic() + await client.db.test.find_one({}) + elapsed = time.monotonic() - start_time + latencies.append(elapsed) + if should_exit: + break + await asyncio.sleep(0.001) + + task = ConcurrentRunner(target=run_task) + await task.start() + original_close = AsyncConnection.close_conn + try: + # Artificially delay the close operation to simulate a slow close + async def mock_close(self, reason): + await asyncio.sleep(close_delay) + await original_close(self, reason) + + AsyncConnection.close_conn = mock_close + + fail_hello = { + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 91, + "appName": "SDAMConnectionCloseTest", + }, + } + async with self.fail_point(fail_hello): + # Wait for server heartbeat to fail + await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + # Wait until all idle connections are closed to simulate real-world conditions + await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10) + # Wait for one more find to complete after the pool has been reset, then shutdown the task + n = len(latencies) + await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find") + should_exit.append(True) + await task.join() + # No operation latency should not significantly exceed 
close_delay + self.assertLessEqual(max(latencies), close_delay * 5.0) + finally: + AsyncConnection.close_conn = original_close + + +class TestServerMonitoringMode(AsyncIntegrationTest): + @async_client_context.require_no_serverless + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + + async def test_rtt_connection_is_enabled_stream(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="stream") + await client.admin.command("ping") + + def predicate(): + for _, server in client._topology._servers.items(): + monitor = server._monitor + if not monitor._stream: + return False + if async_client_context.version >= (4, 4): + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False + else: + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False + return True + + await async_wait_until(predicate, "find all RTT monitors") + + async def test_rtt_connection_is_disabled_poll(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="poll") + + await self.assert_rtt_connection_is_disabled(client) + + async def test_rtt_connection_is_disabled_auto(self): + envs = [ + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"}, + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"K_SERVICE": "gcpservicename"}, + {"FUNCTION_NAME": "gcpfunctionname"}, + {"VERCEL": "1"}, + ] + for env in envs: + with patch.dict("os.environ", env): + client = await self.async_rs_or_single_client(serverMonitoringMode="auto") + await self.assert_rtt_connection_is_disabled(client) + + async def assert_rtt_connection_is_disabled(self, client): + await client.admin.command("ping") + for _, server in client._topology._servers.items(): + monitor = server._monitor + self.assertFalse(monitor._stream) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) + + +class MockTCPHandler(socketserver.BaseRequestHandler): + def handle(self): + self.server.events.append("client connected") + if self.request.recv(1024).strip(): + self.server.events.append("client hello received") + self.request.close() + + +class TCPServer(socketserver.TCPServer): + allow_reuse_address = True + + def handle_request_and_shutdown(self): + self.handle_request() + self.server_close() + + +class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + + if _IS_SYNC: + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown) + await server_thread.start() + _c = await self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + await server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + else: + + async def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") + if (await reader.read(1024)).strip(): + events.append("client hello received") + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "localhost", 9999) + server.events = events + await server.start_serving() + _c = 
self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + await _c.aconnect() + + await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) + await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + await server.wait_closed() + await _c.close() + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + + +# Generate unified tests. +globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py new file mode 100644 index 0000000000..5666612218 --- /dev/null +++ b/test/asynchronous/test_dns.py @@ -0,0 +1,308 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the SRV support tests.""" +from __future__ import annotations + +import glob +import json +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) +from test.utils_shared import async_wait_until +from unittest.mock import MagicMock, patch + +from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.common import validate_read_preference_tags +from pymongo.errors import ConfigurationError +from pymongo.uri_parser_shared import split_hosts + +_IS_SYNC = False + + +class TestDNSRepl(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set" + ) + load_balanced = False + + @async_client_context.require_replica_set + def asyncSetUp(self): + pass + + +class TestDNSLoadBalanced(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced" + ) + load_balanced = True + + @async_client_context.require_load_balancer + def asyncSetUp(self): + pass + + +class TestDNSSharded(AsyncPyMongoTestCase): + if _IS_SYNC: + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded") + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded" + ) + load_balanced = False + + @async_client_context.require_mongos + def asyncSetUp(self): + pass + + +def create_test(test_case): + async def run_test(self): + uri = test_case["uri"] + seeds = test_case.get("seeds") + num_seeds = test_case.get("numSeeds", len(seeds or [])) + hosts = test_case.get("hosts") + num_hosts = test_case.get("numHosts", len(hosts or [])) + + options = 
test_case.get("options", {}) + if "ssl" in options: + options["tls"] = options.pop("ssl") + parsed_options = test_case.get("parsed_options") + # See DRIVERS-1324, unless tls is explicitly set to False we need TLS. + needs_tls = not (options and (options.get("ssl") is False or options.get("tls") is False)) + if needs_tls and not async_client_context.tls: + self.skipTest("this test requires a TLS cluster") + if not needs_tls and async_client_context.tls: + self.skipTest("this test requires a non-TLS cluster") + + if seeds: + seeds = split_hosts(",".join(seeds)) + if hosts: + hosts = frozenset(split_hosts(",".join(hosts))) + + if seeds or num_seeds: + result = await parse_uri(uri, validate=True) + if seeds is not None: + self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) + if num_seeds is not None: + self.assertEqual(len(result["nodelist"]), num_seeds) + if options: + opts = result["options"] + if "readpreferencetags" in opts: + rpts = validate_read_preference_tags( + "readPreferenceTags", opts.pop("readpreferencetags") + ) + opts["readPreferenceTags"] = rpts + self.assertEqual(result["options"], options) + if parsed_options: + for opt, expected in parsed_options.items(): + if opt == "user": + self.assertEqual(result["username"], expected) + elif opt == "password": + self.assertEqual(result["password"], expected) + elif opt == "auth_database" or opt == "db": + self.assertEqual(result["database"], expected) + + hostname = next(iter(async_client_context.client.nodes))[0] + # The replica set members must be configured as 'localhost'. + if hostname == "localhost": + copts = async_client_context.default_client_options.copy() + # Remove tls since SRV parsing should add it automatically. + copts.pop("tls", None) + if async_client_context.tls: + # Our test certs don't support the SRV hosts used in these + # tests. + copts["tlsAllowInvalidHostnames"] = True + + client = self.simple_client(uri, **copts) + if client._options.connect: + await client.aconnect() + if num_seeds is not None: + self.assertEqual(len(client._topology_settings.seeds), num_seeds) + if hosts is not None: + await async_wait_until( + lambda: hosts == client.nodes, "match test hosts to client nodes" + ) + if num_hosts is not None: + await async_wait_until( + lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts" + ) + if test_case.get("ping", True): + await client.admin.command("ping") + # XXX: we should block until SRV poller runs at least once + # and re-run these assertions. 
+ else: + try: + await parse_uri(uri) + except (ConfigurationError, ValueError): + pass + else: + self.fail("failed to raise an exception") + + return run_test + + +def create_tests(cls): + for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")): + test_suffix, _ = os.path.splitext(os.path.basename(filename)) + with open(filename) as dns_test_file: + test_method = create_test(json.load(dns_test_file)) + setattr(cls, "test_" + test_suffix, test_method) + + +create_tests(TestDNSRepl) +create_tests(TestDNSLoadBalanced) +create_tests(TestDNSSharded) + + +class TestParsingErrors(AsyncPyMongoTestCase): + async def test_invalid_host(self): + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + await client.aconnect() + + +class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): + async def test_connect_case_insensitive(self): + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + await client.aconnect() + self.assertGreater(len(client.topology_description.server_descriptions()), 1) + + +class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase): + """ + Initial DNS Seedlist Discovery prose tests + https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests + """ + + async def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases): + for case in test_cases: + with patch("dns.asyncresolver.resolve") as mock_resolver: + + async def mock_resolve(query, record_type, *args, **kwargs): + mock_srv = MagicMock() + mock_srv.target.to_text.return_value = case["mock_target"] + return [mock_srv] + + mock_resolver.side_effect = mock_resolve + domain = case["query"].split("._tcp.")[1] + connection_string = f"mongodb+srv://{domain}" + if "expected_error" not in case: + await parse_uri(connection_string) + else: + try: + await parse_uri(connection_string) + except ConfigurationError as e: + self.assertIn(case["expected_error"], str(e)) + else: + self.fail(f"ConfigurationError was not raised for query: {case['query']}") + + async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self): + with patch("dns.asyncresolver.resolve"): + await parse_uri("mongodb+srv://localhost/") + await parse_uri("mongodb+srv://mongo.local/") + + async def test_2_throw_when_return_address_does_not_end_with_srv_domain(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost.mongodb", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.evil.com", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongo.local", + "mock_target": "test_1.evil.com", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + async def test_3_throw_when_return_address_is_identical_to_srv_hostname(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "mongo.local", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + async def 
test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "test_1.cluster_1localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "test_1.my_hostmongo.local", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "cluster.testmongodb.com", + "expected_error": "Invalid SRV host", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + async def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.mongodb.com", + }, + ] + await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 2b22bd8b76..9e8758a1cd 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -41,6 +41,7 @@ from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.helpers import anext from pymongo.daemon import _spawn_daemon +from pymongo.uri_parser_shared import _parse_kms_tls_options try: from pymongo.pyopenssl_context import IS_PYOPENSSL @@ -64,7 +65,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, @@ -73,7 +74,7 @@ is_greenthread_patched, ) -from bson import DatetimeMS, Decimal128, encode, json_util +from bson import BSON, DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.errors import BSONError @@ -94,6 +95,7 @@ EncryptionError, InvalidOperation, OperationFailure, + PyMongoError, ServerSelectionTimeoutError, WriteError, ) @@ -140,7 +142,7 @@ def test_init(self): self.assertEqual(opts._mongocryptd_bypass_spawn, False) self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd") self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"]) - self.assertEqual(opts._kms_ssl_contexts, {}) + self.assertEqual(opts._kms_tls_options, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init_spawn_args(self): @@ -164,30 +166,38 @@ def test_init_spawn_args(self): ) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def test_init_kms_tls_options(self): + async def test_init_kms_tls_options(self): # Error cases: + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1}) with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'): - AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1}) + AsyncMongoClient(auto_encryption_opts=opts) + tls_opts: Any for tls_opts in [ {"kmip": {"tls": True, "tlsInsecure": True}}, {"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}}, {"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}}, ]: + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"): - opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) + AsyncMongoClient(auto_encryption_opts=opts) + opts = AutoEncryptionOpts( + {}, "k.d", kms_tls_options={"kmip": 
{"tlsCAFile": "does-not-exist"}} + ) with self.assertRaises(FileNotFoundError): - AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}) + AsyncMongoClient(auto_encryption_opts=opts) # Success cases: tls_opts: Any for tls_opts in [None, {}]: opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) - self.assertEqual(opts._kms_ssl_contexts, {}) + kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + self.assertEqual(kms_tls_contexts, {}) opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}}) - ctx = opts._kms_ssl_contexts["kmip"] + _kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + ctx = _kms_ssl_contexts["kmip"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) - ctx = opts._kms_ssl_contexts["aws"] + ctx = _kms_ssl_contexts["aws"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) opts = AutoEncryptionOpts( @@ -195,7 +205,8 @@ def test_init_kms_tls_options(self): "k.d", kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}, ) - ctx = opts._kms_ssl_contexts["kmip"] + _kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + ctx = _kms_ssl_contexts["kmip"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) @@ -739,7 +750,7 @@ def allowable_errors(self, op): return errors -async def create_test(scenario_def, test, name): +def create_test(scenario_def, test, name): @async_client_context.require_test_commands async def run_scenario(self): await self.run_scenario(scenario_def, test) @@ -2224,7 +2235,7 @@ async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] + ctx = encryption._io_callbacks._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") self.assertFalse(ctx.check_ocsp_endpoint) @@ -2419,6 +2430,310 @@ async def test_05_roundtrip_encrypted_unindexed(self): self.assertEqual(decrypted, val) +# https://github.com/mongodb/specifications/blob/527e22d5090ec48bf1e144c45fc831de0f1935f6/source/client-side-encryption/tests/README.md#25-test-lookup +class TestLookupProse(AsyncEncryptionIntegrationTest): + @async_client_context.require_no_standalone + @async_client_context.require_version_min(7, 0, -1) + async def asyncSetUp(self): + await super().asyncSetUp() + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + await encrypted_client.drop_database("db") + + key_doc = json_data("etc", "data", "lookup", "key-doc.json") + await create_key_vault(encrypted_client.db.keyvault, key_doc) + self.addAsyncCleanup(async_client_context.client.drop_database, "db") + + await encrypted_client.db.create_collection( + "csfle", + validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle.json")}, + ) + await encrypted_client.db.create_collection( + "csfle2", + validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle2.json")}, + ) + await encrypted_client.db.create_collection( + "qe", encryptedFields=json_data("etc", "data", "lookup", "schema-qe.json") + ) + await 
encrypted_client.db.create_collection( + "qe2", encryptedFields=json_data("etc", "data", "lookup", "schema-qe2.json") + ) + await encrypted_client.db.create_collection("no_schema") + await encrypted_client.db.create_collection("no_schema2") + + unencrypted_client = await self.async_rs_or_single_client() + + await encrypted_client.db.csfle.insert_one({"csfle": "csfle"}) + doc = await unencrypted_client.db.csfle.find_one() + self.assertTrue(isinstance(doc["csfle"], Binary)) + await encrypted_client.db.csfle2.insert_one({"csfle2": "csfle2"}) + doc = await unencrypted_client.db.csfle2.find_one() + self.assertTrue(isinstance(doc["csfle2"], Binary)) + await encrypted_client.db.qe.insert_one({"qe": "qe"}) + doc = await unencrypted_client.db.qe.find_one() + self.assertTrue(isinstance(doc["qe"], Binary)) + await encrypted_client.db.qe2.insert_one({"qe2": "qe2"}) + doc = await unencrypted_client.db.qe2.find_one() + self.assertTrue(isinstance(doc["qe2"], Binary)) + await encrypted_client.db.no_schema.insert_one({"no_schema": "no_schema"}) + await encrypted_client.db.no_schema2.insert_one({"no_schema2": "no_schema2"}) + + await encrypted_client.close() + await unencrypted_client.close() + + @async_client_context.require_version_min(8, 1, -1) + async def test_1_csfle_joins_no_schema(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"csfle": "csfle", "matched": [{"no_schema": "no_schema"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_2_qe_joins_no_schema(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.qe.aggregate( + [ + {"$match": {"qe": "qe"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ] + ) + ) + self.assertEqual(doc, {"qe": "qe", "matched": [{"no_schema": "no_schema"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_3_no_schema_joins_csfle(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "csfle", + "as": "matched", + "pipeline": [{"$match": {"csfle": "csfle"}}, {"$project": {"_id": 0}}], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"csfle": "csfle"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_4_no_schema_joins_qe(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + 
kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "qe", + "as": "matched", + "pipeline": [ + {"$match": {"qe": "qe"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"qe": "qe"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_5_csfle_joins_csfle2(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "csfle2", + "as": "matched", + "pipeline": [ + {"$match": {"csfle2": "csfle2"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"csfle": "csfle", "matched": [{"csfle2": "csfle2"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_6_qe_joins_qe2(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.qe.aggregate( + [ + {"$match": {"qe": "qe"}}, + { + "$lookup": { + "from": "qe2", + "as": "matched", + "pipeline": [ + {"$match": {"qe2": "qe2"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ] + ) + ) + self.assertEqual(doc, {"qe": "qe", "matched": [{"qe2": "qe2"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_7_no_schema_joins_no_schema2(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = await anext( + await encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "no_schema2", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema2": "no_schema2"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"no_schema2": "no_schema2"}]}) + + @async_client_context.require_version_min(8, 1, -1) + async def test_8_csfle_joins_qe(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + with self.assertRaises(PyMongoError) as exc: + _ = await anext( + await encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "qe"}}, + { + "$lookup": { + "from": "qe", + "as": "matched", + "pipeline": [{"$match": {"qe": "qe"}}, {"$project": {"_id": 0}}], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertIn("not supported", str(exc)) + + @async_client_context.require_version_max(8, 1, -1) + async def test_9_error(self): + encrypted_client = await self.async_rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + with self.assertRaises(PyMongoError) as exc: + _ = await 
anext( + await encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertIn("Upgrade", str(exc)) + + # https://github.com/mongodb/specifications/blob/072601/source/client-side-encryption/tests/README.md#rewrap class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest): MASTER_KEYS: Mapping[str, Mapping[str, Any]] = { @@ -2982,9 +3297,10 @@ async def test_02_no_fields(self): ) async def test_03_invalid_keyid(self): + # checkAuthForCreateCollection can be removed when SERVER-102101 is fixed. with self.assertRaisesRegex( EncryptedCollectionError, - "create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", + "(create|checkAuthForCreateCollection).encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", ): await self.client_encryption.create_encrypted_collection( database=self.db, diff --git a/test/asynchronous/test_examples.py b/test/asynchronous/test_examples.py new file mode 100644 index 0000000000..9e9b208f51 --- /dev/null +++ b/test/asynchronous/test_examples.py @@ -0,0 +1,1456 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MongoDB documentation examples in Python.""" +from __future__ import annotations + +import asyncio +import datetime +import functools +import sys +import threading +import time +from test.asynchronous.helpers import ConcurrentRunner + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils_shared import async_wait_until + +import pymongo +from pymongo.asynchronous.helpers import anext +from pymongo.errors import ConnectionFailure, OperationFailure +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import ServerApi +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestSampleShellCommands(AsyncIntegrationTest): + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.inventory.drop() + + async def asyncTearDown(self): + # Run after every test. 
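+        # Drop everything the examples may have created so each test starts clean.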
+ await self.db.inventory.drop() + await self.client.drop_database("pymongo_test") + + async def test_first_three_examples(self): + db = self.db + + # Start Example 1 + await db.inventory.insert_one( + { + "item": "canvas", + "qty": 100, + "tags": ["cotton"], + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + } + ) + # End Example 1 + + self.assertEqual(await db.inventory.count_documents({}), 1) + + # Start Example 2 + cursor = db.inventory.find({"item": "canvas"}) + # End Example 2 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 3 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "tags": ["blank", "red"], + "size": {"h": 14, "w": 21, "uom": "cm"}, + }, + { + "item": "mat", + "qty": 85, + "tags": ["gray"], + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + }, + { + "item": "mousepad", + "qty": 25, + "tags": ["gel", "blue"], + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + }, + ] + ) + # End Example 3 + + self.assertEqual(await db.inventory.count_documents({}), 4) + + async def test_query_top_level_fields(self): + db = self.db + + # Start Example 6 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 6 + + self.assertEqual(await db.inventory.count_documents({}), 5) + + # Start Example 7 + cursor = db.inventory.find({}) + # End Example 7 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 9 + cursor = db.inventory.find({"status": "D"}) + # End Example 9 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 10 + cursor = db.inventory.find({"status": {"$in": ["A", "D"]}}) + # End Example 10 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 11 + cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}}) + # End Example 11 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 12 + cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) + # End Example 12 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 13 + cursor = db.inventory.find( + {"status": "A", "$or": [{"qty": {"$lt": 30}}, {"item": {"$regex": "^p"}}]} + ) + # End Example 13 + + self.assertEqual(len(await cursor.to_list()), 2) + + async def test_query_embedded_documents(self): + db = self.db + + # Start Example 14 + # Subdocument key order matters in a few of these examples so we have + # to use bson.son.SON instead of a Python dict. 
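+        # Examples 15 and 16 below rely on this: the same fields in a different
+        # order match different documents.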
+ from bson.son import SON + + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": SON([("h", 14), ("w", 21), ("uom", "cm")]), + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": SON([("h", 22.85), ("w", 30), ("uom", "cm")]), + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": SON([("h", 10), ("w", 15.25), ("uom", "cm")]), + "status": "A", + }, + ] + ) + # End Example 14 + + # Start Example 15 + cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) + # End Example 15 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 16 + cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) + # End Example 16 + + self.assertEqual(len(await cursor.to_list()), 0) + + # Start Example 17 + cursor = db.inventory.find({"size.uom": "in"}) + # End Example 17 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 18 + cursor = db.inventory.find({"size.h": {"$lt": 15}}) + # End Example 18 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 19 + cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) + # End Example 19 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_query_arrays(self): + db = self.db + + # Start Example 20 + await db.inventory.insert_many( + [ + {"item": "journal", "qty": 25, "tags": ["blank", "red"], "dim_cm": [14, 21]}, + {"item": "notebook", "qty": 50, "tags": ["red", "blank"], "dim_cm": [14, 21]}, + { + "item": "paper", + "qty": 100, + "tags": ["red", "blank", "plain"], + "dim_cm": [14, 21], + }, + {"item": "planner", "qty": 75, "tags": ["blank", "red"], "dim_cm": [22.85, 30]}, + {"item": "postcard", "qty": 45, "tags": ["blue"], "dim_cm": [10, 15.25]}, + ] + ) + # End Example 20 + + # Start Example 21 + cursor = db.inventory.find({"tags": ["red", "blank"]}) + # End Example 21 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 22 + cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}}) + # End Example 22 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 23 + cursor = db.inventory.find({"tags": "red"}) + # End Example 23 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 24 + cursor = db.inventory.find({"dim_cm": {"$gt": 25}}) + # End Example 24 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 25 + cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}}) + # End Example 25 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 26 + cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) + # End Example 26 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 27 + cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}}) + # End Example 27 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 28 + cursor = db.inventory.find({"tags": {"$size": 3}}) + # End Example 28 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_query_array_of_documents(self): + db = self.db + + # Start Example 29 + # Subdocument key order matters in a few of these examples so we have + # to use bson.son.SON instead of a Python dict. 
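+        # The same order-sensitivity applies to whole-document matches against
+        # array elements, as Examples 30 and 31 demonstrate.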
+ from bson.son import SON + + await db.inventory.insert_many( + [ + { + "item": "journal", + "instock": [ + SON([("warehouse", "A"), ("qty", 5)]), + SON([("warehouse", "C"), ("qty", 15)]), + ], + }, + {"item": "notebook", "instock": [SON([("warehouse", "C"), ("qty", 5)])]}, + { + "item": "paper", + "instock": [ + SON([("warehouse", "A"), ("qty", 60)]), + SON([("warehouse", "B"), ("qty", 15)]), + ], + }, + { + "item": "planner", + "instock": [ + SON([("warehouse", "A"), ("qty", 40)]), + SON([("warehouse", "B"), ("qty", 5)]), + ], + }, + { + "item": "postcard", + "instock": [ + SON([("warehouse", "B"), ("qty", 15)]), + SON([("warehouse", "C"), ("qty", 35)]), + ], + }, + ] + ) + # End Example 29 + + # Start Example 30 + cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])}) + # End Example 30 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 31 + cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])}) + # End Example 31 + + self.assertEqual(len(await cursor.to_list()), 0) + + # Start Example 32 + cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}}) + # End Example 32 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 33 + cursor = db.inventory.find({"instock.qty": {"$lte": 20}}) + # End Example 33 + + self.assertEqual(len(await cursor.to_list()), 5) + + # Start Example 34 + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) + # End Example 34 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 35 + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) + # End Example 35 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 36 + cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}}) + # End Example 36 + + self.assertEqual(len(await cursor.to_list()), 4) + + # Start Example 37 + cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"}) + # End Example 37 + + self.assertEqual(len(await cursor.to_list()), 2) + + async def test_query_null(self): + db = self.db + + # Start Example 38 + await db.inventory.insert_many([{"_id": 1, "item": None}, {"_id": 2}]) + # End Example 38 + + # Start Example 39 + cursor = db.inventory.find({"item": None}) + # End Example 39 + + self.assertEqual(len(await cursor.to_list()), 2) + + # Start Example 40 + cursor = db.inventory.find({"item": {"$type": 10}}) + # End Example 40 + + self.assertEqual(len(await cursor.to_list()), 1) + + # Start Example 41 + cursor = db.inventory.find({"item": {"$exists": False}}) + # End Example 41 + + self.assertEqual(len(await cursor.to_list()), 1) + + async def test_projection(self): + db = self.db + + # Start Example 42 + await db.inventory.insert_many( + [ + { + "item": "journal", + "status": "A", + "size": {"h": 14, "w": 21, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 5}], + }, + { + "item": "notebook", + "status": "A", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "C", "qty": 5}], + }, + { + "item": "paper", + "status": "D", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "A", "qty": 60}], + }, + { + "item": "planner", + "status": "D", + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 40}], + }, + { + "item": "postcard", + "status": "A", + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "instock": [{"warehouse": "B", "qty": 15}, {"warehouse": "C", "qty": 35}], + }, + ] + ) + # End Example 42 + + # Start 
Example 43 + cursor = db.inventory.find({"status": "A"}) + # End Example 43 + + self.assertEqual(len(await cursor.to_list()), 3) + + # Start Example 44 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1}) + # End Example 44 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 45 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "_id": 0}) + # End Example 45 + + async for doc in cursor: + self.assertFalse("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 46 + cursor = db.inventory.find({"status": "A"}, {"status": 0, "instock": 0}) + # End Example 46 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertFalse("status" in doc) + self.assertTrue("size" in doc) + self.assertFalse("instock" in doc) + + # Start Example 47 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "size.uom": 1}) + # End Example 47 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertTrue("size" in doc) + self.assertFalse("instock" in doc) + size = doc["size"] + self.assertTrue("uom" in size) + self.assertFalse("h" in size) + self.assertFalse("w" in size) + + # Start Example 48 + cursor = db.inventory.find({"status": "A"}, {"size.uom": 0}) + # End Example 48 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertTrue("size" in doc) + self.assertTrue("instock" in doc) + size = doc["size"] + self.assertFalse("uom" in size) + self.assertTrue("h" in size) + self.assertTrue("w" in size) + + # Start Example 49 + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "instock.qty": 1}) + # End Example 49 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertTrue("instock" in doc) + for subdoc in doc["instock"]: + self.assertFalse("warehouse" in subdoc) + self.assertTrue("qty" in subdoc) + + # Start Example 50 + cursor = db.inventory.find( + {"status": "A"}, {"item": 1, "status": 1, "instock": {"$slice": -1}} + ) + # End Example 50 + + async for doc in cursor: + self.assertTrue("_id" in doc) + self.assertTrue("item" in doc) + self.assertTrue("status" in doc) + self.assertFalse("size" in doc) + self.assertTrue("instock" in doc) + self.assertEqual(len(doc["instock"]), 1) + + async def test_update_and_replace(self): + db = self.db + + # Start Example 51 + await db.inventory.insert_many( + [ + { + "item": "canvas", + "qty": 100, + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "mat", + "qty": 85, + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "mousepad", + "qty": 25, + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + "status": "P", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + 
"qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketchbook", + "qty": 80, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketch pad", + "qty": 95, + "size": {"h": 22.85, "w": 30.5, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 51 + + # Start Example 52 + await db.inventory.update_one( + {"item": "paper"}, + {"$set": {"size.uom": "cm", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) + # End Example 52 + + async for doc in db.inventory.find({"item": "paper"}): + self.assertEqual(doc["size"]["uom"], "cm") + self.assertEqual(doc["status"], "P") + self.assertTrue("lastModified" in doc) + + # Start Example 53 + await db.inventory.update_many( + {"qty": {"$lt": 50}}, + {"$set": {"size.uom": "in", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) + # End Example 53 + + async for doc in db.inventory.find({"qty": {"$lt": 50}}): + self.assertEqual(doc["size"]["uom"], "in") + self.assertEqual(doc["status"], "P") + self.assertTrue("lastModified" in doc) + + # Start Example 54 + await db.inventory.replace_one( + {"item": "paper"}, + { + "item": "paper", + "instock": [{"warehouse": "A", "qty": 60}, {"warehouse": "B", "qty": 40}], + }, + ) + # End Example 54 + + async for doc in db.inventory.find({"item": "paper"}, {"_id": 0}): + self.assertEqual(len(doc.keys()), 2) + self.assertTrue("item" in doc) + self.assertTrue("instock" in doc) + self.assertEqual(len(doc["instock"]), 2) + + async def test_delete(self): + db = self.db + + # Start Example 55 + await db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) + # End Example 55 + + self.assertEqual(await db.inventory.count_documents({}), 5) + + # Start Example 57 + await db.inventory.delete_many({"status": "A"}) + # End Example 57 + + self.assertEqual(await db.inventory.count_documents({}), 3) + + # Start Example 58 + await db.inventory.delete_one({"status": "D"}) + # End Example 58 + + self.assertEqual(await db.inventory.count_documents({}), 2) + + # Start Example 56 + await db.inventory.delete_many({}) + # End Example 56 + + self.assertEqual(await db.inventory.count_documents({}), 0) + + @async_client_context.require_change_streams + async def test_change_streams(self): + db = self.db + done = False + + async def insert_docs(): + nonlocal done + while not done: + await db.inventory.insert_one({"username": "alice"}) + await db.inventory.delete_one({"username": "alice"}) + await asyncio.sleep(0.005) + + t = ConcurrentRunner(target=insert_docs) + await t.start() + + try: + # 1. 
The database for reactive, real-time applications + # Start Changestream Example 1 + cursor = await db.inventory.watch() + await anext(cursor) + # End Changestream Example 1 + await cursor.close() + + # Start Changestream Example 2 + cursor = await db.inventory.watch(full_document="updateLookup") + await anext(cursor) + # End Changestream Example 2 + await cursor.close() + + # Start Changestream Example 3 + resume_token = cursor.resume_token + cursor = await db.inventory.watch(resume_after=resume_token) + await anext(cursor) + # End Changestream Example 3 + await cursor.close() + + # Start Changestream Example 4 + pipeline = [ + {"$match": {"fullDocument.username": "alice"}}, + {"$addFields": {"newField": "this is an added field!"}}, + ] + cursor = await db.inventory.watch(pipeline=pipeline) + await anext(cursor) + # End Changestream Example 4 + await cursor.close() + finally: + done = True + await t.join() + + async def test_aggregate_examples(self): + db = self.db + + # Start Aggregation Example 1 + await db.sales.aggregate([{"$match": {"items.fruit": "banana"}}, {"$sort": {"date": 1}}]) + # End Aggregation Example 1 + + # Start Aggregation Example 2 + await db.sales.aggregate( + [ + {"$unwind": "$items"}, + {"$match": {"items.fruit": "banana"}}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "count": {"$sum": "$items.quantity"}, + } + }, + {"$project": {"dayOfWeek": "$_id.day", "numberSold": "$count", "_id": 0}}, + {"$sort": {"numberSold": 1}}, + ] + ) + # End Aggregation Example 2 + + # Start Aggregation Example 3 + await db.sales.aggregate( + [ + {"$unwind": "$items"}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "items_sold": {"$sum": "$items.quantity"}, + "revenue": {"$sum": {"$multiply": ["$items.quantity", "$items.price"]}}, + } + }, + { + "$project": { + "day": "$_id.day", + "revenue": 1, + "items_sold": 1, + "discount": { + "$cond": {"if": {"$lte": ["$revenue", 250]}, "then": 25, "else": 0} + }, + } + }, + ] + ) + # End Aggregation Example 3 + + # Start Aggregation Example 4 + await db.air_alliances.aggregate( + [ + { + "$lookup": { + "from": "air_airlines", + "let": {"constituents": "$airlines"}, + "pipeline": [{"$match": {"$expr": {"$in": ["$name", "$$constituents"]}}}], + "as": "airlines", + } + }, + { + "$project": { + "_id": 0, + "name": 1, + "airlines": { + "$filter": { + "input": "$airlines", + "as": "airline", + "cond": {"$eq": ["$$airline.country", "Canada"]}, + } + }, + } + }, + ] + ) + # End Aggregation Example 4 + + @async_client_context.require_version_min(4, 4) + async def test_aggregate_projection_example(self): + db = self.db + + # Start Aggregation Projection Example 1 + db.inventory.find( + {}, + { + "_id": 0, + "item": 1, + "status": { + "$switch": { + "branches": [ + {"case": {"$eq": ["$status", "A"]}, "then": "Available"}, + {"case": {"$eq": ["$status", "D"]}, "then": "Discontinued"}, + ], + "default": "No status found", + } + }, + "area": { + "$concat": [ + {"$toString": {"$multiply": ["$size.h", "$size.w"]}}, + " ", + "$size.uom", + ] + }, + "reportNumber": {"$literal": 1}, + }, + ) + + # End Aggregation Projection Example 1 + + async def test_commands(self): + db = self.db + await db.restaurants.insert_one({}) + + # Start runCommand Example 1 + await db.command("buildInfo") + # End runCommand Example 1 + + # Start runCommand Example 2 + await db.command("count", "restaurants") + # End runCommand Example 2 + + async def test_index_management(self): + db = self.db + + # Start Index Example 1 + await 
db.records.create_index("score")
+        # End Index Example 1
+
+        # Start Index Example 2
+        await db.restaurants.create_index(
+            [("cuisine", pymongo.ASCENDING), ("name", pymongo.ASCENDING)],
+            partialFilterExpression={"rating": {"$gt": 5}},
+        )
+        # End Index Example 2
+
+    @async_client_context.require_replica_set
+    async def test_misc(self):
+        # Marketing examples
+        client = self.client
+        self.addAsyncCleanup(client.drop_database, "test")
+        self.addAsyncCleanup(client.drop_database, "my_database")
+
+        # 2. Tunable consistency controls
+        collection = client.my_database.my_collection
+        async with client.start_session() as session:
+            await collection.insert_one({"_id": 1}, session=session)
+            await collection.update_one({"_id": 1}, {"$set": {"a": 1}}, session=session)
+            async for _doc in collection.find({}, session=session):
+                pass
+
+        # 3. Exploiting the power of arrays
+        collection = client.test.array_updates_test
+        await collection.update_one(
+            {"_id": 1}, {"$set": {"a.$[i].b": 2}}, array_filters=[{"i.b": 0}]
+        )
+
+
+class TestTransactionExamples(AsyncIntegrationTest):
+    @async_client_context.require_transactions
+    async def test_transactions(self):
+        # Transaction examples
+        client = self.client
+        self.addAsyncCleanup(client.drop_database, "hr")
+        self.addAsyncCleanup(client.drop_database, "reporting")
+
+        employees = client.hr.employees
+        events = client.reporting.events
+        await employees.insert_one({"employee": 3, "status": "Active"})
+        await events.insert_one({"employee": 3, "status": {"new": "Active", "old": None}})
+
+        # Start Transactions Intro Example 1
+
+        async def update_employee_info(session):
+            employees_coll = session.client.hr.employees
+            events_coll = session.client.reporting.events
+
+            async with await session.start_transaction(
+                read_concern=ReadConcern("snapshot"), write_concern=WriteConcern(w="majority")
+            ):
+                await employees_coll.update_one(
+                    {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session
+                )
+                await events_coll.insert_one(
+                    {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session
+                )
+
+                while True:
+                    try:
+                        # Commit uses write concern set at transaction start.
+                        await session.commit_transaction()
+                        print("Transaction committed.")
+                        break
+                    except (ConnectionFailure, OperationFailure) as exc:
+                        # Can retry commit
+                        if exc.has_error_label("UnknownTransactionCommitResult"):
+                            print("UnknownTransactionCommitResult, retrying commit operation ...")
+                            continue
+                        else:
+                            print("Error during commit ...")
+                            raise
+
+        # End Transactions Intro Example 1
+
+        async with client.start_session() as session:
+            await update_employee_info(session)
+
+        employee = await employees.find_one({"employee": 3})
+        assert employee is not None
+        self.assertIsNotNone(employee)
+        self.assertEqual(employee["status"], "Inactive")
+
+        # Start Transactions Retry Example 1
+        async def run_transaction_with_retry(txn_func, session):
+            while True:
+                try:
+                    await txn_func(session)  # performs transaction
+                    break
+                except (ConnectionFailure, OperationFailure) as exc:
+                    print("Transaction aborted. Caught exception during transaction.")
+
+                    # If transient error, retry the whole transaction
+                    if exc.has_error_label("TransientTransactionError"):
+                        print("TransientTransactionError, retrying transaction ...")
+                        continue
+                    else:
+                        raise
+
+        # End Transactions Retry Example 1
+
+        async with client.start_session() as session:
+            await run_transaction_with_retry(update_employee_info, session)
+
+        employee = await employees.find_one({"employee": 3})
+        assert employee is not None
+        self.assertIsNotNone(employee)
+        self.assertEqual(employee["status"], "Inactive")
+
+        # Start Transactions Retry Example 2
+        async def commit_with_retry(session):
+            while True:
+                try:
+                    # Commit uses write concern set at transaction start.
+                    await session.commit_transaction()
+                    print("Transaction committed.")
+                    break
+                except (ConnectionFailure, OperationFailure) as exc:
+                    # Can retry commit
+                    if exc.has_error_label("UnknownTransactionCommitResult"):
+                        print("UnknownTransactionCommitResult, retrying commit operation ...")
+                        continue
+                    else:
+                        print("Error during commit ...")
+                        raise
+
+        # End Transactions Retry Example 2
+
+        # Test commit_with_retry from the previous examples
+        async def _insert_employee_retry_commit(session):
+            async with await session.start_transaction():
+                await employees.insert_one({"employee": 4, "status": "Active"}, session=session)
+                await events.insert_one(
+                    {"employee": 4, "status": {"new": "Active", "old": None}}, session=session
+                )
+
+                await commit_with_retry(session)
+
+        async with client.start_session() as session:
+            await run_transaction_with_retry(_insert_employee_retry_commit, session)
+
+        employee = await employees.find_one({"employee": 4})
+        assert employee is not None
+        self.assertIsNotNone(employee)
+        self.assertEqual(employee["status"], "Active")
+
+        # Start Transactions Retry Example 3
+
+        async def run_transaction_with_retry(txn_func, session):
+            while True:
+                try:
+                    await txn_func(session)  # performs transaction
+                    break
+                except (ConnectionFailure, OperationFailure) as exc:
+                    # If transient error, retry the whole transaction
+                    if exc.has_error_label("TransientTransactionError"):
+                        print("TransientTransactionError, retrying transaction ...")
+                        continue
+                    else:
+                        raise
+
+        async def commit_with_retry(session):
+            while True:
+                try:
+                    # Commit uses write concern set at transaction start.
+                    await session.commit_transaction()
+                    print("Transaction committed.")
+                    break
+                except (ConnectionFailure, OperationFailure) as exc:
+                    # Can retry commit
+                    if exc.has_error_label("UnknownTransactionCommitResult"):
+                        print("UnknownTransactionCommitResult, retrying commit operation ...")
+                        continue
+                    else:
+                        print("Error during commit ...")
+                        raise
+
+        # Updates two collections in a transaction
+
+        async def update_employee_info(session):
+            employees_coll = session.client.hr.employees
+            events_coll = session.client.reporting.events
+
+            async with await session.start_transaction(
+                read_concern=ReadConcern("snapshot"),
+                write_concern=WriteConcern(w="majority"),
+                read_preference=ReadPreference.PRIMARY,
+            ):
+                await employees_coll.update_one(
+                    {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session
+                )
+                await events_coll.insert_one(
+                    {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session
+                )
+
+                await commit_with_retry(session)
+
+        # Start a session.
+        async with client.start_session() as session:
+            try:
+                await run_transaction_with_retry(update_employee_info, session)
+            except Exception:
+                # Do something with error.
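+                # Non-retryable errors reach this handler; the example re-raises them.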
+                raise
+
+        # End Transactions Retry Example 3
+
+        employee = await employees.find_one({"employee": 3})
+        assert employee is not None
+        self.assertIsNotNone(employee)
+        self.assertEqual(employee["status"], "Inactive")
+
+        async def MongoClient(_):
+            return await self.async_rs_client()
+
+        uriString = None
+
+        # Start Transactions withTxn API Example 1
+
+        # For a replica set, include the replica set name and a seedlist of the members in the URI string; e.g.
+        # uriString = 'mongodb://mongodb0.example.com:27017,mongodb1.example.com:27017/?replicaSet=myRepl'
+        # For a sharded cluster, connect to the mongos instances; e.g.
+        # uriString = 'mongodb://mongos0.example.com:27017,mongos1.example.com:27017/'
+
+        client = await MongoClient(uriString)
+        wc_majority = WriteConcern("majority", wtimeout=1000)
+
+        # Prereq: Create collections.
+        await client.get_database("mydb1", write_concern=wc_majority).foo.insert_one({"abc": 0})
+        await client.get_database("mydb2", write_concern=wc_majority).bar.insert_one({"xyz": 0})
+
+        # Step 1: Define the callback that specifies the sequence of operations to perform inside the transaction.
+        async def callback(session):
+            collection_one = session.client.mydb1.foo
+            collection_two = session.client.mydb2.bar
+
+            # Important: You must pass the session to the operations.
+            await collection_one.insert_one({"abc": 1}, session=session)
+            await collection_two.insert_one({"xyz": 999}, session=session)
+
+        # Step 2: Start a client session.
+        async with client.start_session() as session:
+            # Step 3: Use with_transaction to start a transaction, execute the callback, and commit (or abort on error).
+            await session.with_transaction(callback)
+
+        # End Transactions withTxn API Example 1
+
+
+class TestCausalConsistencyExamples(AsyncIntegrationTest):
+    @async_client_context.require_secondaries_count(1)
+    @async_client_context.require_no_mmap
+    async def test_causal_consistency(self):
+        # Causal consistency examples
+        client = self.client
+        self.addAsyncCleanup(client.drop_database, "test")
+        await client.test.drop_collection("items")
+        await client.test.items.insert_one(
+            {"sku": "111", "name": "Peanuts", "start": datetime.datetime.today()}
+        )
+
+        # Start Causal Consistency Example 1
+        async with client.start_session(causal_consistency=True) as s1:
+            current_date = datetime.datetime.today()
+            items = client.get_database(
+                "test",
+                read_concern=ReadConcern("majority"),
+                write_concern=WriteConcern("majority", wtimeout=1000),
+            ).items
+            await items.update_one(
+                {"sku": "111", "end": None}, {"$set": {"end": current_date}}, session=s1
+            )
+            await items.insert_one(
+                {"sku": "nuts-111", "name": "Pecans", "start": current_date}, session=s1
+            )
+        # End Causal Consistency Example 1
+
+        assert s1.cluster_time is not None
+        assert s1.operation_time is not None
+
+        # Start Causal Consistency Example 2
+        async with client.start_session(causal_consistency=True) as s2:
+            s2.advance_cluster_time(s1.cluster_time)
+            s2.advance_operation_time(s1.operation_time)
+
+            items = client.get_database(
+                "test",
+                read_preference=ReadPreference.SECONDARY,
+                read_concern=ReadConcern("majority"),
+                write_concern=WriteConcern("majority", wtimeout=1000),
+            ).items
+            async for item in items.find({"end": None}, session=s2):
+                print(item)
+        # End Causal Consistency Example 2
+
+
+class TestVersionedApiExamples(AsyncIntegrationTest):
+    @async_client_context.require_version_min(4, 7)
+    async def test_versioned_api(self):
+        # Versioned API examples
+        async def MongoClient(_, server_api):
+            return await 
self.async_rs_client(server_api=server_api, connect=False) + + uri = None + + # Start Versioned API Example 1 + from pymongo.server_api import ServerApi + + await MongoClient(uri, server_api=ServerApi("1")) + # End Versioned API Example 1 + + # Start Versioned API Example 2 + await MongoClient(uri, server_api=ServerApi("1", strict=True)) + # End Versioned API Example 2 + + # Start Versioned API Example 3 + await MongoClient(uri, server_api=ServerApi("1", strict=False)) + # End Versioned API Example 3 + + # Start Versioned API Example 4 + await MongoClient(uri, server_api=ServerApi("1", deprecation_errors=True)) + # End Versioned API Example 4 + + @unittest.skip("PYTHON-3167 count has been added to API version 1") + @async_client_context.require_version_min(4, 7) + async def test_versioned_api_migration(self): + # SERVER-58785 + if await async_client_context.is_topology_type( + ["sharded"] + ) and not async_client_context.version.at_least(5, 0, 2): + self.skipTest("This test needs MongoDB 5.0.2 or newer") + + client = await self.async_rs_client(server_api=ServerApi("1", strict=True)) + await client.db.sales.drop() + + # Start Versioned API Example 5 + def strptime(s): + return datetime.datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") + + await client.db.sales.insert_many( + [ + { + "_id": 1, + "item": "abc", + "price": 10, + "quantity": 2, + "date": strptime("2021-01-01T08:00:00Z"), + }, + { + "_id": 2, + "item": "jkl", + "price": 20, + "quantity": 1, + "date": strptime("2021-02-03T09:00:00Z"), + }, + { + "_id": 3, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-03T09:05:00Z"), + }, + { + "_id": 4, + "item": "abc", + "price": 10, + "quantity": 10, + "date": strptime("2021-02-15T08:00:00Z"), + }, + { + "_id": 5, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T09:05:00Z"), + }, + { + "_id": 6, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-15T12:05:10Z"), + }, + { + "_id": 7, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T14:12:12Z"), + }, + { + "_id": 8, + "item": "abc", + "price": 10, + "quantity": 5, + "date": strptime("2021-03-16T20:20:13Z"), + }, + ] + ) + # End Versioned API Example 5 + + with self.assertRaisesRegex( + OperationFailure, + "Provided apiStrict:true, but the command count is not in API Version 1", + ): + await client.db.command("count", "sales", query={}) + # Start Versioned API Example 6 + # pymongo.errors.OperationFailure: Provided apiStrict:true, but the command count is not in API Version 1, full error: {'ok': 0.0, 'errmsg': 'Provided apiStrict:true, but the command count is not in API Version 1', 'code': 323, 'codeName': 'APIStrictError'} + # End Versioned API Example 6 + + # Start Versioned API Example 7 + await client.db.sales.count_documents({}) + # End Versioned API Example 7 + + # Start Versioned API Example 8 + # 8 + # End Versioned API Example 8 + + +class TestSnapshotQueryExamples(AsyncIntegrationTest): + @async_client_context.require_version_min(5, 0) + async def test_snapshot_query(self): + client = self.client + + if not await async_client_context.is_topology_type(["replicaset", "sharded"]): + self.skipTest("Must be a sharded or replicaset") + + self.addAsyncCleanup(client.drop_database, "pets") + db = client.pets + await db.drop_collection("cats") + await db.drop_collection("dogs") + await db.cats.insert_one( + {"name": "Whiskers", "color": "white", "age": 10, "adoptable": True} + ) + await db.dogs.insert_one( + {"name": "Pebbles", "color": 
"Brown", "age": 10, "adoptable": True} + ) + + async def predicate_one(): + return await self.check_for_snapshot(db.cats) + + async def predicate_two(): + return await self.check_for_snapshot(db.dogs) + + await async_wait_until(predicate_two, "success") + await async_wait_until(predicate_one, "success") + + # Start Snapshot Query Example 1 + + db = client.pets + async with client.start_session(snapshot=True) as s: + adoptablePetsCount = ( + await ( + await db.cats.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], + session=s, + ) + ).next() + )["adoptableCatsCount"] + + adoptablePetsCount += ( + await ( + await db.dogs.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], + session=s, + ) + ).next() + )["adoptableDogsCount"] + + print(adoptablePetsCount) + + # End Snapshot Query Example 1 + db = client.retail + self.addAsyncCleanup(client.drop_database, "retail") + await db.drop_collection("sales") + + saleDate = datetime.datetime.now() + await db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate}) + + async def predicate_three(): + return await self.check_for_snapshot(db.sales) + + await async_wait_until(predicate_three, "success") + + # Start Snapshot Query Example 2 + db = client.retail + async with client.start_session(snapshot=True) as s: + _ = ( + await ( + await db.sales.aggregate( + [ + { + "$match": { + "$expr": { + "$gt": [ + "$saleDate", + { + "$dateSubtract": { + "startDate": "$$NOW", + "unit": "day", + "amount": 1, + } + }, + ] + } + } + }, + {"$count": "totalDailySales"}, + ], + session=s, + ) + ).next() + )["totalDailySales"] + + # End Snapshot Query Example 2 + + async def check_for_snapshot(self, collection): + """Wait for snapshot reads to become available to prevent this error: + [246:SnapshotUnavailable]: Unable to read from a snapshot due to pending collection catalog changes; please retry the operation. Snapshot timestamp is Timestamp(1646666892, 4). Collection minimum is Timestamp(1646666892, 5) (on localhost:27017, modern retry, attempt 1) + From https://github.com/mongodb/mongo-ruby-driver/commit/7c4117b58e3d12e237f7536f7521e18fc15f79ac + """ + async with self.client.start_session(snapshot=True) as s: + try: + if await collection.find_one(session=s): + return True + return False + except OperationFailure as e: + # Retry them as the server demands... + if e.code == 246: # SnapshotUnavailable + return False + raise + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index affdacde91..3f864367de 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py new file mode 100644 index 0000000000..f886601f36 --- /dev/null +++ b/test/asynchronous/test_gridfs.py @@ -0,0 +1,603 @@ +# +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the gridfs package.""" +from __future__ import annotations + +import asyncio +import datetime +import sys +import threading +import time +from io import BytesIO +from test.asynchronous.helpers import ConcurrentRunner +from unittest.mock import patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils import async_joinall +from test.utils_shared import one + +import gridfs +from bson.binary import Binary +from gridfs.asynchronous.grid_file import DEFAULT_CHUNK_SIZE, AsyncGridOutCursor +from gridfs.errors import CorruptGridFile, FileExists, NoFile +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, +) +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + + +class JustWrite(ConcurrentRunner): + def __init__(self, fs, n): + super().__init__() + self.fs = fs + self.n = n + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + await file.write(b"hello") + await file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, fs, n, results): + super().__init__() + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = await self.fs.get("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" + + +class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase): + db: AsyncDatabase + + async def asyncSetUp(self): + await super().asyncSetUp() + self.db = AsyncMongoClient(connect=False).pymongo_test + + async def test_gridfs(self): + self.assertRaises(TypeError, gridfs.AsyncGridFS, "foo") + self.assertRaises(TypeError, gridfs.AsyncGridFS, self.db, 5) + + +class TestGridfs(AsyncIntegrationTest): + fs: gridfs.AsyncGridFS + alt: gridfs.AsyncGridFS + + async def asyncSetUp(self): + await super().asyncSetUp() + self.fs = gridfs.AsyncGridFS(self.db) + self.alt = gridfs.AsyncGridFS(self.db, "alt") + await self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) + + async def test_basic(self): + oid = await self.fs.put(b"hello world") + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + await self.fs.delete(oid) + with self.assertRaises(NoFile): + await self.fs.get(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.fs.get("foo") + oid = await self.fs.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.fs.get("foo")).read()) + + async def test_multi_chunk_delete(self): + await self.db.fs.drop() + self.assertEqual(0, await 
self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFS(self.db) + oid = await gfs.put(b"hello", chunkSize=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await self.db.fs.chunks.count_documents({})) + await gfs.delete(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_list(self): + self.assertEqual([], await self.fs.list()) + await self.fs.put(b"hello world") + self.assertEqual([], await self.fs.list()) + + # PYTHON-598: in server versions before 2.5.x, creating an index on + # filename, uploadDate causes list() to include None. + await self.fs.get_last_version() + self.assertEqual([], await self.fs.list()) + + await self.fs.put(b"", filename="mike") + await self.fs.put(b"foo", filename="test") + await self.fs.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.fs.list())) + + async def test_empty_file(self): + oid = await self.fs.put(b"") + self.assertEqual(b"", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + raw = await self.db.fs.files.find_one() + assert raw is not None + self.assertEqual(0, raw["length"]) + self.assertEqual(oid, raw["_id"]) + self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) + self.assertEqual(255 * 1024, raw["chunkSize"]) + self.assertNotIn("md5", raw) + + async def test_corrupt_chunk(self): + files_id = await self.fs.put(b"foobar") + await self.db.fs.chunks.update_one( + {"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}} + ) + try: + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.read() + + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.readline() + finally: + await self.fs.delete(files_id) + + async def test_put_ensures_index(self): + chunks = self.db.fs.chunks + files = self.db.fs.files + # Ensure the collections are removed. 
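+        # put() should lazily recreate the GridFS indexes verified below.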
+ await chunks.drop() + await files.drop() + await self.fs.put(b"junk") + + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in (await chunks.index_information()).values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + + async def test_alt_collection(self): + oid = await self.alt.put(b"hello world") + self.assertEqual(b"hello world", await (await self.alt.get(oid)).read()) + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + await self.alt.delete(oid) + with self.assertRaises(NoFile): + await self.alt.get(oid) + self.assertEqual(0, await self.db.alt.files.count_documents({})) + self.assertEqual(0, await self.db.alt.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.alt.get("foo") + oid = await self.alt.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.alt.get("foo")).read()) + + await self.alt.put(b"", filename="mike") + await self.alt.put(b"foo", filename="test") + await self.alt.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.alt.list())) + + async def test_threaded_reads(self): + await self.fs.put(b"hello", _id="test") + + tasks = [] + results: list = [] + for i in range(10): + tasks.append(JustRead(self.fs, 10, results)) + await tasks[i].start() + + await async_joinall(tasks) + + self.assertEqual(100 * [b"hello"], results) + + async def test_threaded_writes(self): + tasks = [] + for i in range(10): + tasks.append(JustWrite(self.fs, 10)) + await tasks[i].start() + + await async_joinall(tasks) + + f = await self.fs.get_last_version("test") + self.assertEqual(await f.read(), b"hello") + + # Should have created 100 versions of 'test' file + self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"})) + + async def test_get_last_version(self): + one = await self.fs.put(b"foo", filename="test") + await asyncio.sleep(0.01) + two = self.fs.new_file(filename="test") + await two.write(b"bar") + await two.close() + await asyncio.sleep(0.01) + two = two._id + three = await self.fs.put(b"baz", filename="test") + + self.assertEqual(b"baz", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(three) + self.assertEqual(b"bar", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(one) + with self.assertRaises(NoFile): + await self.fs.get_last_version("test") + + async def test_get_last_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author") + + self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(one) + + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author2") + + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", await (await 
self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", await (await self.fs.get_last_version(filename="test")).read()) + + with self.assertRaises(NoFile): + await self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + await self.fs.get_last_version(filename="nottest", author="author1") + + await self.fs.delete(one) + await self.fs.delete(two) + + async def test_get_version(self): + await self.fs.put(b"foo", filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"bar", filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"baz", filename="test") + await asyncio.sleep(0.01) + + self.assertEqual(b"foo", await (await self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", await (await self.fs.get_version("test", 2)).read()) + + self.assertEqual(b"baz", await (await self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", await (await self.fs.get_version("test", -3)).read()) + + with self.assertRaises(NoFile): + await self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + await self.fs.get_version("test", -4) + + async def test_get_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author1") + await asyncio.sleep(0.01) + three = await self.fs.put(b"baz", filename="test", author="author2") + + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=-2)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=-1)).read(), + ) + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=0)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=1)).read(), + ) + self.assertEqual( + b"baz", + await (await self.fs.get_version(filename="test", author="author2", version=0)).read(), + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=-1)).read() + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=2)).read() + ) + + with self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author1", version=2) + + await self.fs.delete(one) + await self.fs.delete(two) + await self.fs.delete(three) + + async def test_put_filelike(self): + oid = await self.fs.put(BytesIO(b"hello world"), chunk_size=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + + async def test_file_exists(self): + oid = await self.fs.put(b"hello") + with self.assertRaises(FileExists): + await self.fs.put(b"world", _id=oid) + + one = self.fs.new_file(_id=123) + await one.write(b"some content") + await one.close() + + # Attempt to upload a file with more chunks to the same _id. 
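+        # Shrinking _UPLOAD_BUFFER_SIZE to one chunk makes write() flush
+        # immediately, so FileExists surfaces during write() instead of close().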
+ with patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): + two = self.fs.new_file(_id=123) + with self.assertRaises(FileExists): + await two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) + # Original file is still readable (no extra chunks were uploaded). + self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + two = self.fs.new_file(_id=123) + await two.write(b"some content") + with self.assertRaises(FileExists): + await two.close() + # Original file is still readable. + self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + async def test_exists(self): + oid = await self.fs.put(b"hello") + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + + self.assertFalse(await self.fs.exists(filename="mike")) + self.assertFalse(await self.fs.exists("mike")) + + oid = await self.fs.put(b"hello", filename="mike", foo=12) + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + self.assertTrue(await self.fs.exists(filename="mike")) + self.assertTrue(await self.fs.exists({"filename": "mike"})) + self.assertTrue(await self.fs.exists(foo=12)) + self.assertTrue(await self.fs.exists({"foo": 12})) + self.assertTrue(await self.fs.exists(foo={"$gt": 11})) + self.assertTrue(await self.fs.exists({"foo": {"$gt": 11}})) + + self.assertFalse(await self.fs.exists(foo=13)) + self.assertFalse(await self.fs.exists({"foo": 13})) + self.assertFalse(await self.fs.exists(foo={"$gt": 12})) + self.assertFalse(await self.fs.exists({"foo": {"$gt": 12}})) + + async def test_put_unicode(self): + with self.assertRaises(TypeError): + await self.fs.put("hello") + + oid = await self.fs.put("hello", encoding="utf-8") + self.assertEqual(b"hello", await (await self.fs.get(oid)).read()) + self.assertEqual("utf-8", (await self.fs.get(oid)).encoding) + + oid = await self.fs.put("aé", encoding="iso-8859-1") + self.assertEqual("aé".encode("iso-8859-1"), await (await self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (await self.fs.get(oid)).encoding) + + async def test_missing_length_iter(self): + # Test fix that guards against PHP-237 + await self.fs.put(b"", filename="empty") + doc = await self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None + doc.pop("length") + await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) + f = await self.fs.get_last_version(filename="empty") + + async def iterate_file(grid_file): + async for _chunk in grid_file: + pass + return True + + self.assertTrue(await iterate_file(f)) + + async def test_gridfs_lazy_connect(self): + client = await self.async_single_client( + "badhost", connect=False, serverSelectionTimeoutMS=10 + ) + db = client.db + gfs = gridfs.AsyncGridFS(db) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.list() + + fs = gridfs.AsyncGridFS(db) + f = fs.new_file() + with self.assertRaises(ServerSelectionTimeoutError): + await f.close() + + async def test_gridfs_find(self): + await self.fs.put(b"test2", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test2+", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test1", filename="one") + await asyncio.sleep(0.01) + await self.fs.put(b"test2++", filename="two") + files = self.db.fs.files + self.assertEqual(3, await files.count_documents({"filename": "two"})) + self.assertEqual(4, await files.count_documents({})) + cursor 
= self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2)
+        gout = await cursor.next()
+        self.assertEqual(b"test1", await gout.read())
+        await cursor.rewind()
+        gout = await cursor.next()
+        self.assertEqual(b"test1", await gout.read())
+        gout = await cursor.next()
+        self.assertEqual(b"test2+", await gout.read())
+        with self.assertRaises(StopAsyncIteration):
+            await cursor.__anext__()
+        await cursor.rewind()
+        items = await cursor.to_list()
+        self.assertEqual(len(items), 2)
+        await cursor.rewind()
+        items = await cursor.to_list(1)
+        self.assertEqual(len(items), 1)
+        await cursor.close()
+        self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
+
+    async def test_delete_not_initialized(self):
+        # Creating a cursor with invalid arguments will not run __init__
+        # but will still call __del__.
+        cursor = AsyncGridOutCursor.__new__(AsyncGridOutCursor)  # Skip calling __init__
+        with self.assertRaises(TypeError):
+            cursor.__init__(self.db.fs.files, {}, {"_id": True})  # type: ignore
+        cursor.__del__()  # no error
+
+    async def test_gridfs_find_one(self):
+        self.assertEqual(None, await self.fs.find_one())
+
+        id1 = await self.fs.put(b"test1", filename="file1")
+        res = await self.fs.find_one()
+        assert res is not None
+        self.assertEqual(b"test1", await res.read())
+
+        id2 = await self.fs.put(b"test2", filename="file2", meta="data")
+        res1 = await self.fs.find_one(id1)
+        assert res1 is not None
+        self.assertEqual(b"test1", await res1.read())
+        res2 = await self.fs.find_one(id2)
+        assert res2 is not None
+        self.assertEqual(b"test2", await res2.read())
+
+        res3 = await self.fs.find_one({"filename": "file1"})
+        assert res3 is not None
+        self.assertEqual(b"test1", await res3.read())
+
+        res4 = await self.fs.find_one(id2)
+        assert res4 is not None
+        self.assertEqual("data", res4.meta)
+
+    async def test_grid_in_non_int_chunksize(self):
+        # Lua, and perhaps other buggy GridFS clients, store size as a float.
+        data = b"data"
+        await self.fs.put(data, filename="f")
+        await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
+
+        self.assertEqual(data, await (await self.fs.get_version("f")).read())
+
+    async def test_unacknowledged(self):
+        # w=0 is prohibited.
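+        # GridFS requires acknowledged writes; w=0 could leave files and chunks inconsistent.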
+ with self.assertRaises(ConfigurationError): + gridfs.AsyncGridFS((await self.async_rs_or_single_client(w=0)).pymongo_test) + + async def test_md5(self): + gin = self.fs.new_file() + await gin.write(b"no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.get(gin._id) + self.assertIsNone(gout.md5) + + _id = await self.fs.put(b"still no md5 sum") + gout = await self.fs.get(_id) + self.assertIsNone(gout.md5) + + +class TestGridfsReplicaSet(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + + @classmethod + @async_client_context.require_connection + async def asyncTearDownClass(cls): + await async_client_context.client.drop_database("gfsreplica") + + async def test_gridfs_replica_set(self): + rsc = await self.async_rs_client( + w=async_client_context.w, read_preference=ReadPreference.SECONDARY + ) + + fs = gridfs.AsyncGridFS(rsc.gfsreplica, "gfsreplicatest") + + gin = fs.new_file() + self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) + + oid = await fs.put(b"foo") + content = await (await fs.get(oid)).read() + self.assertEqual(b"foo", content) + + async def test_gridfs_secondary(self): + secondary_host, secondary_port = one(await self.client.secondaries) + secondary_connection = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) + + # Should detect it's connected to secondary and not attempt to + # create index + fs = gridfs.AsyncGridFS(secondary_connection.gfsreplica, "gfssecondarytest") + + # This won't detect secondary, raises error + with self.assertRaises(NotPrimaryError): + await fs.put(b"foo") + + async def test_gridfs_secondary_lazy(self): + # Should detect it's connected to secondary and not attempt to + # create index. + secondary_host, secondary_port = one(await self.client.secondaries) + client = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) + + # Still no connection. + fs = gridfs.AsyncGridFS(client.gfsreplica, "gfssecondarylazytest") + + # Connects, doesn't create index. + with self.assertRaises(NoFile): + await fs.get_last_version() + with self.assertRaises(NotPrimaryError): + await fs.put("data", encoding="utf-8") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_gridfs_bucket.py b/test/asynchronous/test_gridfs_bucket.py new file mode 100644 index 0000000000..e8d063b712 --- /dev/null +++ b/test/asynchronous/test_gridfs_bucket.py @@ -0,0 +1,599 @@ +# +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the gridfs package.""" +from __future__ import annotations + +import asyncio +import datetime +import itertools +import sys +import threading +import time +from io import BytesIO +from test.asynchronous.helpers import ConcurrentRunner +from unittest.mock import patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils import async_joinall +from test.utils_shared import one + +import gridfs +from bson.binary import Binary +from bson.int64 import Int64 +from bson.objectid import ObjectId +from bson.son import SON +from gridfs.errors import CorruptGridFile, NoFile +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, + WriteConcernError, +) +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + + +class JustWrite(ConcurrentRunner): + def __init__(self, gfs, num): + super().__init__() + self.gfs = gfs + self.num = num + self.daemon = True + + async def run(self): + for _ in range(self.num): + file = self.gfs.open_upload_stream("test") + await file.write(b"hello") + await file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, gfs, num, results): + super().__init__() + self.gfs = gfs + self.num = num + self.results = results + self.daemon = True + + async def run(self): + for _ in range(self.num): + file = await self.gfs.open_download_stream_by_name("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" + + +class TestGridfs(AsyncIntegrationTest): + fs: gridfs.AsyncGridFSBucket + alt: gridfs.AsyncGridFSBucket + + async def asyncSetUp(self): + await super().asyncSetUp() + self.fs = gridfs.AsyncGridFSBucket(self.db) + self.alt = gridfs.AsyncGridFSBucket(self.db, bucket_name="alt") + await self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) + + async def test_basic(self): + oid = await self.fs.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + await self.fs.delete(oid) + with self.assertRaises(NoFile): + await self.fs.open_download_stream(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_multi_chunk_delete(self): + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFSBucket(self.db) + oid = await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await self.db.fs.chunks.count_documents({})) + await gfs.delete(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_delete_by_name(self): + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFSBucket(self.db) + await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await 
self.db.fs.chunks.count_documents({})) + await gfs.delete_by_name("test_filename") + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_empty_file(self): + oid = await self.fs.upload_from_stream("test_filename", b"") + self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + raw = await self.db.fs.files.find_one() + assert raw is not None + self.assertEqual(0, raw["length"]) + self.assertEqual(oid, raw["_id"]) + self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) + self.assertEqual(255 * 1024, raw["chunkSize"]) + self.assertNotIn("md5", raw) + + async def test_corrupt_chunk(self): + files_id = await self.fs.upload_from_stream("test_filename", b"foobar") + await self.db.fs.chunks.update_one( + {"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}} + ) + try: + out = await self.fs.open_download_stream(files_id) + with self.assertRaises(CorruptGridFile): + await out.read() + + out = await self.fs.open_download_stream(files_id) + with self.assertRaises(CorruptGridFile): + await out.readline() + finally: + await self.fs.delete(files_id) + + async def test_upload_ensures_index(self): + chunks = self.db.fs.chunks + files = self.db.fs.files + # Ensure the collections are removed. + await chunks.drop() + await files.drop() + await self.fs.upload_from_stream("filename", b"junk") + + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in (await chunks.index_information()).values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + + async def test_ensure_index_shell_compat(self): + files = self.db.fs.files + for i, j in itertools.combinations_with_replacement([1, 1.0, Int64(1)], 2): + # Create the index with different numeric types (as might be done + # from the mongo shell). + shell_index = [("filename", i), ("uploadDate", j)] + await self.db.command( + "createIndexes", + files.name, + indexes=[{"key": SON(shell_index), "name": "filename_1.0_uploadDate_1.0"}], + ) + + # No error. 
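+ # The shell-style float/Int64 key values above must compare equal to the driver's int keys, so the upload recognizes the existing index instead of recreating it.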
+ await self.fs.upload_from_stream("filename", b"data") + + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + await files.drop() + + async def test_alt_collection(self): + oid = await self.alt.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", await (await self.alt.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + await self.alt.delete(oid) + with self.assertRaises(NoFile): + await self.alt.open_download_stream(oid) + self.assertEqual(0, await self.db.alt.files.count_documents({})) + self.assertEqual(0, await self.db.alt.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.alt.open_download_stream("foo") + await self.alt.upload_from_stream("foo", b"hello world") + self.assertEqual( + b"hello world", await (await self.alt.open_download_stream_by_name("foo")).read() + ) + + await self.alt.upload_from_stream("mike", b"") + await self.alt.upload_from_stream("test", b"foo") + await self.alt.upload_from_stream("hello world", b"") + + self.assertEqual( + {"mike", "test", "hello world", "foo"}, + {k["filename"] for k in await self.db.alt.files.find().to_list()}, + ) + + async def test_threaded_reads(self): + await self.fs.upload_from_stream("test", b"hello") + + threads = [] + results: list = [] + for i in range(10): + threads.append(JustRead(self.fs, 10, results)) + await threads[i].start() + + await async_joinall(threads) + + self.assertEqual(100 * [b"hello"], results) + + async def test_threaded_writes(self): + threads = [] + for i in range(10): + threads.append(JustWrite(self.fs, 10)) + await threads[i].start() + + await async_joinall(threads) + + fstr = await self.fs.open_download_stream_by_name("test") + self.assertEqual(await fstr.read(), b"hello") + + # Should have created 100 versions of 'test' file + self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"})) + + async def test_get_last_version(self): + one = await self.fs.upload_from_stream("test", b"foo") + await asyncio.sleep(0.01) + two = self.fs.open_upload_stream("test") + await two.write(b"bar") + await two.close() + await asyncio.sleep(0.01) + two = two._id + three = await self.fs.upload_from_stream("test", b"baz") + + self.assertEqual(b"baz", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(three) + self.assertEqual(b"bar", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(one) + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("test") + + async def test_get_version(self): + await self.fs.upload_from_stream("test", b"foo") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("test", b"bar") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("test", b"baz") + await asyncio.sleep(0.01) + + self.assertEqual( + b"foo", await (await self.fs.open_download_stream_by_name("test", revision=0)).read() + ) + self.assertEqual( + b"bar", await (await self.fs.open_download_stream_by_name("test", revision=1)).read() + ) + self.assertEqual( + b"baz", await (await self.fs.open_download_stream_by_name("test", revision=2)).read() + ) + + self.assertEqual( + b"baz", await (await 
self.fs.open_download_stream_by_name("test", revision=-1)).read() + ) + self.assertEqual( + b"bar", await (await self.fs.open_download_stream_by_name("test", revision=-2)).read() + ) + self.assertEqual( + b"foo", await (await self.fs.open_download_stream_by_name("test", revision=-3)).read() + ) + + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("test", revision=3) + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("test", revision=-4) + + async def test_upload_from_stream(self): + oid = await self.fs.upload_from_stream( + "test_file", BytesIO(b"hello world"), chunk_size_bytes=1 + ) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read()) + + async def test_upload_from_stream_with_id(self): + oid = ObjectId() + await self.fs.upload_from_stream_with_id( + oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1 + ) + self.assertEqual(b"custom id", await (await self.fs.open_download_stream(oid)).read()) + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3) + @async_client_context.require_failCommand_fail_point + async def test_upload_bulk_write_error(self): + # Test BulkWriteError from insert_many is converted to an insert_one style error. + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + gin = self.fs.open_upload_stream("test_file", chunk_size_bytes=1) + async with self.fail_point(cause_wce): + # Assert we raise WriteConcernError, not BulkWriteError. + with self.assertRaises(WriteConcernError): + await gin.write(b"hello world") + # 3 chunks were uploaded. + self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.abort() + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 10) + async def test_upload_batching(self): + async with self.fs.open_upload_stream("test_file", chunk_size_bytes=1) as gin: + await gin.write(b"s" * (10 - 1)) + # No chunks were uploaded yet. + self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.write(b"s") + # All chunks were uploaded since we hit the _UPLOAD_BUFFER_CHUNKS limit. 
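+ # chunk_size_bytes=1 makes every byte its own chunk, so this one-byte write fills the patched 10-chunk buffer and flushes all buffered chunks in a single batched insert.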
+ self.assertEqual(10, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + + async def test_open_upload_stream(self): + gin = self.fs.open_upload_stream("from_stream") + await gin.write(b"from stream") + await gin.close() + self.assertEqual(b"from stream", await (await self.fs.open_download_stream(gin._id)).read()) + + async def test_open_upload_stream_with_id(self): + oid = ObjectId() + gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") + await gin.write(b"from stream with custom id") + await gin.close() + self.assertEqual( + b"from stream with custom id", await (await self.fs.open_download_stream(oid)).read() + ) + + async def test_missing_length_iter(self): + # Test fix that guards against PHP-237 + await self.fs.upload_from_stream("empty", b"") + doc = await self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None + doc.pop("length") + await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) + fstr = await self.fs.open_download_stream_by_name("empty") + + async def iterate_file(grid_file): + async for _ in grid_file: + pass + return True + + self.assertTrue(await iterate_file(fstr)) + + async def test_gridfs_lazy_connect(self): + client = await self.async_single_client( + "badhost", connect=False, serverSelectionTimeoutMS=0 + ) + cdb = client.db + gfs = gridfs.AsyncGridFSBucket(cdb) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.delete(0) + + gfs = gridfs.AsyncGridFSBucket(cdb) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.upload_from_stream("test", b"") # Still no connection. + + async def test_gridfs_find(self): + await self.fs.upload_from_stream("two", b"test2") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("two", b"test2+") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("one", b"test1") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("two", b"test2++") + files = self.db.fs.files + self.assertEqual(3, await files.count_documents({"filename": "two"})) + self.assertEqual(4, await files.count_documents({})) + cursor = self.fs.find( + {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2 + ) + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + await cursor.rewind() + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + gout = await cursor.next() + self.assertEqual(b"test2+", await gout.read()) + with self.assertRaises(StopAsyncIteration): + await cursor.next() + await cursor.close() + self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) + + async def test_grid_in_non_int_chunksize(self): + # Lua, and perhaps other buggy AsyncGridFS clients, store size as a float. + data = b"data" + await self.fs.upload_from_stream("f", data) + await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) + + self.assertEqual(data, await (await self.fs.open_download_stream_by_name("f")).read()) + + async def test_unacknowledged(self): + # w=0 is prohibited. 
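+ # Like AsyncGridFS, the bucket rejects an unacknowledged database at construction time rather than failing on individual operations.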
+ with self.assertRaises(ConfigurationError): + gridfs.AsyncGridFSBucket((await self.async_rs_or_single_client(w=0)).pymongo_test) + + async def test_rename(self): + _id = await self.fs.upload_from_stream("first_name", b"testing") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("first_name")).read() + ) + + await self.fs.rename(_id, "second_name") + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("first_name") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("second_name")).read() + ) + + async def test_rename_by_name(self): + _id = await self.fs.upload_from_stream("first_name", b"testing") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("first_name")).read() + ) + + await self.fs.rename_by_name("first_name", "second_name") + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("first_name") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("second_name")).read() + ) + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", 5) + async def test_abort(self): + gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5) + await gin.write(b"test1") + await gin.write(b"test2") + await gin.write(b"test3") + self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.abort() + self.assertTrue(gin.closed) + with self.assertRaises(ValueError): + await gin.write(b"test4") + self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + + async def test_download_to_stream(self): + file1 = BytesIO(b"hello world") + # Test with one chunk. + oid = await self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. + await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + file1.seek(0) + oid = await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + async def test_download_to_stream_by_name(self): + file1 = BytesIO(b"hello world") + # Test with one chunk. + _ = await self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream_by_name("one_chunk", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. 
+ await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + file1.seek(0) + await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + + file2 = BytesIO() + await self.fs.download_to_stream_by_name("many_chunks", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + async def test_md5(self): + gin = self.fs.open_upload_stream("no md5") + await gin.write(b"no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.open_download_stream(gin._id) + self.assertIsNone(gout.md5) + + gin = self.fs.open_upload_stream_with_id(ObjectId(), "also no md5") + await gin.write(b"also no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.open_download_stream(gin._id) + self.assertIsNone(gout.md5) + + +class TestGridfsBucketReplicaSet(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + + @classmethod + @async_client_context.require_connection + async def asyncTearDownClass(cls): + await async_client_context.client.drop_database("gfsbucketreplica") + + async def test_gridfs_replica_set(self): + rsc = await self.async_rs_client( + w=async_client_context.w, read_preference=ReadPreference.SECONDARY + ) + + gfs = gridfs.AsyncGridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") + oid = await gfs.upload_from_stream("test_filename", b"foo") + content = await (await gfs.open_download_stream(oid)).read() + self.assertEqual(b"foo", content) + + async def test_gridfs_secondary(self): + secondary_host, secondary_port = one(await self.client.secondaries) + secondary_connection = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) + + # Should detect it's connected to secondary and not attempt to + # create index + gfs = gridfs.AsyncGridFSBucket( + secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest" + ) + + # This won't detect secondary, raises error + with self.assertRaises(NotPrimaryError): + await gfs.upload_from_stream("test_filename", b"foo") + + async def test_gridfs_secondary_lazy(self): + # Should detect it's connected to secondary and not attempt to + # create index. + secondary_host, secondary_port = one(await self.client.secondaries) + client = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) + + # Still no connection. + gfs = gridfs.AsyncGridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest") + + # Connects, doesn't create index. + with self.assertRaises(NoFile): + await gfs.open_download_stream_by_name("test_filename") + with self.assertRaises(NotPrimaryError): + await gfs.upload_from_stream("test_filename", b"data") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_gridfs_spec.py b/test/asynchronous/test_gridfs_spec.py new file mode 100644 index 0000000000..f3dc14fbdc --- /dev/null +++ b/test/asynchronous/test_gridfs_spec.py @@ -0,0 +1,39 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the AsyncGridFS unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py new file mode 100644 index 0000000000..aa8a205021 --- /dev/null +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -0,0 +1,98 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the monitoring of the server heartbeats.""" +from __future__ import annotations + +import sys +from test.asynchronous.utils import AsyncMockPool + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.utils_shared import HeartbeatEventListener, async_wait_until + +from pymongo.asynchronous.monitor import Monitor +from pymongo.errors import ConnectionFailure +from pymongo.hello import Hello, HelloCompat + +_IS_SYNC = False + + +class TestHeartbeatMonitoring(AsyncIntegrationTest): + async def create_mock_monitor(self, responses, uri, expected_results): + listener = HeartbeatEventListener() + with client_knobs( + heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1 + ): + + class MockMonitor(Monitor): + async def _check_with_socket(self, *args, **kwargs): + if isinstance(responses[1], Exception): + raise responses[1] + return Hello(responses[1]), 99 + + _ = await self.async_single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=AsyncMockPool, + connect=True, + ) + + expected_len = len(expected_results) + # Wait for *at least* expected_len number of results. The + # monitor thread may run multiple times during the execution + # of this test. + await async_wait_until( + lambda: len(listener.events) >= expected_len, "publish all events" + ) + + # zip gives us len(expected_results) pairs. 
+ for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) + + async def test_standalone(self): + responses = ( + ("a", 27017), + {HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1}, + ) + uri = "mongodb://a:27017" + expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"] + + await self.create_mock_monitor(responses, uri, expected_results) + + async def test_standalone_error(self): + responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE")) + uri = "mongodb://a:27017" + # _check_with_socket failing results in a second attempt. + expected_results = [ + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + ] + + await self.create_mock_monitor(responses, uri, expected_results) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py new file mode 100644 index 0000000000..890788fc56 --- /dev/null +++ b/test/asynchronous/test_index_management.py @@ -0,0 +1,379 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the index management spec tests.""" +from __future__ import annotations + +import asyncio +import os +import pathlib +import sys +import time +import uuid +from typing import Any, Mapping + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils_shared import AllowListEventListener, OvertCommandListener + +from pymongo.errors import OperationFailure +from pymongo.operations import SearchIndexModel +from pymongo.read_concern import ReadConcern +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + +pytestmark = pytest.mark.search_index + +# Location of JSON test specifications.
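+# The asynchronous tests live one directory below the synchronous ones, so the shared JSON specs sit in the parent directory.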
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") + +_NAME = "test-search-index" + + +class TestCreateSearchIndex(AsyncIntegrationTest): + async def test_inputs(self): + listener = AllowListEventListener("createSearchIndexes") + client = self.simple_client(event_listeners=[listener]) + coll = client.test.test + await coll.drop() + definition = dict(mappings=dict(dynamic=True)) + model_kwarg_list: list[Mapping[str, Any]] = [ + dict(definition=definition, name=None), + dict(definition=definition, name="test"), + ] + for model_kwargs in model_kwarg_list: + model = SearchIndexModel(**model_kwargs) + with self.assertRaises(OperationFailure): + await coll.create_search_index(model) + with self.assertRaises(OperationFailure): + await coll.create_search_index(model_kwargs) + + listener.reset() + with self.assertRaises(OperationFailure): + await coll.create_search_index({"definition": definition, "arbitraryOption": 1}) + self.assertEqual( + {"definition": definition, "arbitraryOption": 1}, + listener.events[0].command["indexes"][0], + ) + + listener.reset() + with self.assertRaises(OperationFailure): + await coll.create_search_index({"definition": definition, "type": "search"}) + self.assertEqual( + {"definition": definition, "type": "search"}, listener.events[0].command["indexes"][0] + ) + + +class SearchIndexIntegrationBase(AsyncPyMongoTestCase): + db_name = "test_search_index_base" + + @classmethod + def setUpClass(cls) -> None: + cls.url = os.environ.get("MONGODB_URI") + cls.username = os.environ["DB_USER"] + cls.password = os.environ["DB_PASSWORD"] + cls.listener = OvertCommandListener() + + async def asyncSetUp(self) -> None: + self.client = self.simple_client( + self.url, + username=self.username, + password=self.password, + event_listeners=[self.listener], + ) + await self.client.drop_database(_NAME) + self.db = self.client[self.db_name] + + async def asyncTearDown(self): + await self.client.drop_database(_NAME) + + async def wait_for_ready(self, coll, name=_NAME, predicate=None): + """Wait for a search index to be ready.""" + indices: list[Mapping[str, Any]] = [] + if predicate is None: + predicate = lambda index: index.get("queryable") is True + + while True: + indices = await (await coll.list_search_indexes(name)).to_list() + if len(indices) and predicate(indices[0]): + return indices[0] + await asyncio.sleep(5) + + +class TestSearchIndexIntegration(SearchIndexIntegrationBase): + db_name = "test_search_index" + + async def test_comment_field(self): + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0`` that implicitly passes its type. + search_definition = {"mappings": {"dynamic": False}} + self.listener.reset() + implicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-implicit", "definition": search_definition}, comment="foo" + ) + event = self.listener.events[0] + self.assertEqual(event.command["comment"], "foo") + + # Get the index definition. 
+ self.listener.reset() + await (await coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next() + event = self.listener.events[0] + self.assertEqual(event.command["comment"], "foo") + + +class TestSearchIndexProse(SearchIndexIntegrationBase): + db_name = "test_search_index_prose" + + async def test_case_1(self): + """Driver can successfully create and list search indexes.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + + # Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. Use the following definition: + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + await coll0.insert_one({}) + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, _NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``: + # An index with the ``name`` of ``test-search-index`` is present and the index has a field ``queryable`` with a value of ``true``. + index = await self.wait_for_ready(coll0) + + # Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }`` + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model["definition"]) + + async def test_case_2(self): + """Driver can successfully create multiple indexes in batch.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create two new search indexes on ``coll0`` with the ``createSearchIndexes`` helper. + name1 = "test-search-index-1" + name2 = "test-search-index-2" + definition = {"mappings": {"dynamic": False}} + index_definitions: list[dict[str, Any]] = [ + {"name": name1, "definition": definition}, + {"name": name2, "definition": definition}, + ] + await coll0.create_search_indexes( + [SearchIndexModel(i["definition"], i["name"]) for i in index_definitions] + ) + + # Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``. + indices = await (await coll0.list_search_indexes()).to_list() + names = [i["name"] for i in indices] + self.assertIn(name1, names) + self.assertIn(name2, names) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied. + # An index with the ``name`` of ``test-search-index-1`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index1``. + # An index with the ``name`` of ``test-search-index-2`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index2``. + index1 = await self.wait_for_ready(coll0, name1) + index2 = await self.wait_for_ready(coll0, name2) + + # Assert that ``index1`` and ``index2`` have the property ``latestDefinition`` whose value is ``{ "mappings" : { "dynamic" : false } }`` + for index in [index1, index2]: + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], definition) + + async def test_case_3(self): + """Driver can successfully drop search indexes.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
+ coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0``. + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, "test-search-index") + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``. + await self.wait_for_ready(coll0) + + # Run a ``dropSearchIndex`` on ``coll0``, using ``test-search-index`` for the name. + await coll0.drop_search_index(_NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array. + t0 = time.time() + while True: + indices = await (await coll0.list_search_indexes()).to_list() + if indices: + break + if (time.time() - t0) / 60 > 5: + raise TimeoutError("Timed out waiting for index deletion") + await asyncio.sleep(5) + + async def test_case_4(self): + """Driver can update a search index.""" + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Create a new search index on ``coll0``. + model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index"``. + self.assertEqual(resp, _NAME) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``. + await self.wait_for_ready(coll0) + + # Run a ``updateSearchIndex`` on ``coll0``. + # Assert that the command does not error and the server responds with a success. + model2: dict[str, Any] = {"name": _NAME, "definition": {"mappings": {"dynamic": True}}} + await coll0.update_search_index(_NAME, model2["definition"]) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied: + # An index with the ``name`` of ``test-search-index`` is present. This index is referred to as ``index``. + # The index has a field ``queryable`` with a value of ``true`` and has a field ``status`` with the value of ``READY``. + predicate = lambda index: index.get("queryable") is True and index.get("status") == "READY" + await self.wait_for_ready(coll0, predicate=predicate) + + # Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``. + index = (await (await coll0.list_search_indexes(_NAME)).to_list())[0] + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model2["definition"]) + + async def test_case_5(self): + """``dropSearchIndex`` suppresses namespace not found errors.""" + # Create a driver-side collection object for a randomly generated collection name. Do not create this collection on the server. + coll0 = self.db[f"col{uuid.uuid4()}"] + + # Run a ``dropSearchIndex`` command and assert that no error is thrown. 
+ await coll0.drop_search_index("foo") + + async def test_case_6(self): + """Driver can successfully create and list search indexes with non-default readConcern and writeConcern.""" + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Apply a write concern ``WriteConcern(w=1)`` and a read concern with ``ReadConcern(level="majority")`` to ``coll0``. + coll0 = coll0.with_options( + write_concern=WriteConcern(w="1"), read_concern=ReadConcern(level="majority") + ) + + # Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. + name = "test-search-index-case6" + model = {"name": name, "definition": {"mappings": {"dynamic": False}}} + resp = await coll0.create_search_index(model) + + # Assert that the command returns the name of the index: ``"test-search-index-case6"``. + self.assertEqual(resp, name) + + # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``: + # - An index with the ``name`` of ``test-search-index-case6`` is present and the index has a field ``queryable`` with a value of ``true``. + index = await self.wait_for_ready(coll0, name) + + # Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }`` + self.assertIn("latestDefinition", index) + self.assertEqual(index["latestDefinition"], model["definition"]) + + async def test_case_7(self): + """Driver handles index types.""" + + # Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``). + coll0 = self.db[f"col{uuid.uuid4()}"] + await coll0.insert_one({}) + + # Use these search and vector search definitions for indexes. + search_definition = {"mappings": {"dynamic": False}} + vector_search_definition = { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean", + }, + ] + } + + # Create a new search index on ``coll0`` that implicitly passes its type. + implicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-implicit", "definition": search_definition} + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=implicit_search_resp)).next() + + # Assert that the index model contains the correct index type: ``"search"``. + self.assertEqual(resp["type"], "search") + + # Create a new search index on ``coll0`` that explicitly passes its type. + explicit_search_resp = await coll0.create_search_index( + model={"name": _NAME + "-explicit", "type": "search", "definition": search_definition} + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=explicit_search_resp)).next() + + # Assert that the index model contains the correct index type: ``"search"``. + self.assertEqual(resp["type"], "search") + + # Create a new vector search index on ``coll0`` that explicitly passes its type. + explicit_vector_resp = await coll0.create_search_index( + model={ + "name": _NAME + "-vector", + "type": "vectorSearch", + "definition": vector_search_definition, + } + ) + + # Get the index definition. + resp = await (await coll0.list_search_indexes(name=explicit_vector_resp)).next() + + # Assert that the index model contains the correct index type: ``"vectorSearch"``. 
+ self.assertEqual(resp["type"], "vectorSearch") + + # Catch the error raised when trying to create a vector search index without specifying the type + with self.assertRaises(OperationFailure) as e: + await coll0.create_search_index( + model={"name": _NAME + "-error", "definition": vector_search_definition} + ) + self.assertIn("Attribute mappings missing.", e.exception.details["errmsg"]) + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_json_util_integration.py b/test/asynchronous/test_json_util_integration.py new file mode 100644 index 0000000000..4c02792d89 --- /dev/null +++ b/test/asynchronous/test_json_util_integration.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from test.asynchronous import AsyncIntegrationTest +from typing import Any, List, MutableMapping + +from bson import Binary, Code, DBRef, ObjectId, json_util +from bson.binary import USER_DEFINED_SUBTYPE + +_IS_SYNC = False + + +class TestJsonUtilRoundtrip(AsyncIntegrationTest): + async def test_cursor(self): + db = self.db + + await db.drop_collection("test") + docs: List[MutableMapping[str, Any]] = [ + {"foo": [1, 2]}, + {"bar": {"hello": "world"}}, + {"code": Code("function x() { return 1; }")}, + {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, + {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, + ] + + await db.test.insert_many(docs) + reloaded_docs = json_util.loads(json_util.dumps(await (db.test.find()).to_list())) + for doc in docs: + self.assertTrue(doc in reloaded_docs) diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py new file mode 100644 index 0000000000..127fdfd24d --- /dev/null +++ b/test/asynchronous/test_load_balancer.py @@ -0,0 +1,199 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Load Balancer unified spec tests.""" +from __future__ import annotations + +import asyncio +import gc +import os +import pathlib +import sys +import threading +from asyncio import Event +from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.asynchronous.utils import async_get_pool + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils_shared import ( + async_wait_until, + create_async_event, +) + +from pymongo.asynchronous.helpers import anext + +_IS_SYNC = False + +pytestmark = pytest.mark.load_balancer + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") + +# Generate unified tests. 
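+# generate_test_classes builds one TestCase subclass per JSON spec file; merging them into globals() makes them visible to unittest discovery.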
+globals().update(generate_test_classes(_TEST_PATH, module=__name__)) + + +class TestLB(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + async def test_connections_are_only_returned_once(self): + if "PyPy" in sys.version: + # Tracked in PYTHON-3011 + self.skipTest("Test is flaky on PyPy") + pool = await async_get_pool(self.client) + n_conns = len(pool.conns) + await self.db.test.find_one({}) + self.assertEqual(len(pool.conns), n_conns) + await (await self.db.test.aggregate([{"$limit": 1}])).to_list() + self.assertEqual(len(pool.conns), n_conns) + + @async_client_context.require_load_balancer + async def test_unpin_committed_transaction(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + async with client.start_session() as session: + async with await session.start_transaction(): + self.assertEqual(pool.active_sockets, 0) + await coll.insert_one({}, session=session) + self.assertEqual(pool.active_sockets, 1) # Pinned. + self.assertEqual(pool.active_sockets, 1) # Still pinned. + self.assertEqual(pool.active_sockets, 0) # Unpinned. + + @async_client_context.require_failCommand_fail_point + async def test_cursor_gc(self): + async def create_resource(coll): + cursor = coll.find({}, batch_size=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + @async_client_context.require_failCommand_fail_point + async def test_command_cursor_gc(self): + async def create_resource(coll): + cursor = await coll.aggregate([], batchSize=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + async def _test_no_gc_deadlock(self, create_resource): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + await coll.insert_many([{} for _ in range(10)]) + self.assertEqual(pool.active_sockets, 0) + # Cause the initial find attempt to fail to induce a reference cycle. + args = { + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "aggregate"], + "closeConnection": True, + }, + } + async with self.fail_point(args): + resource = await create_resource(coll) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. + + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. + await coll.delete_many({}) + + @async_client_context.require_transactions + async def test_session_gc(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + session = client.start_session() + await session.start_transaction() + await client.test_session_gc.test.find_one({}, session=session) + # Cleanup the transaction left open on the server unless we're + # testing serverless which does not support killSessions. + if not async_client_context.serverless: + self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id]) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. 
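+ # Lock the pool from another task, then garbage collect the pinned session: the test only passes if session cleanup never tries to re-acquire the already-held lock.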
+ + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. + await client[self.db.name].test.delete_many({}) + + +class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = create_async_event() + self.unlock = create_async_event() + + async def lock_pool(self): + async with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + unlock_pool = await self.wait(self.unlock, 10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") + + async def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) # type: ignore[call-arg] + else: + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_logger.py b/test/asynchronous/test_logger.py index a2e8b35c5f..d024735fd8 100644 --- a/test/asynchronous/test_logger.py +++ b/test/asynchronous/test_logger.py @@ -15,7 +15,7 @@ import os from test import unittest -from test.asynchronous import AsyncIntegrationTest +from test.asynchronous import AsyncIntegrationTest, async_client_context from unittest.mock import patch from bson import json_util @@ -97,6 +97,49 @@ async def test_logging_without_listeners(self): await c.db.test.insert_one({"x": "1"}) self.assertGreater(len(cm.records), 0) + @async_client_context.require_failCommand_fail_point + async def test_logging_retry_read_attempts(self): + await self.db.test.insert_one({"x": "1"}) + + async with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + await self.db.test.find_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage() + ] + self.assertEqual(len(retry_messages), 1) + + @async_client_context.require_failCommand_fail_point + @async_client_context.require_retryable_writes + async def test_logging_retry_write_attempts(self): + async with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + "failCommands": ["insert"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + await self.db.test.insert_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage() + ] + self.assertEqual(len(retry_messages), 1) + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_max_staleness.py b/test/asynchronous/test_max_staleness.py new file mode 100644 index 0000000000..b6e15f9158 --- /dev/null +++ b/test/asynchronous/test_max_staleness.py @@ -0,0 +1,149 @@ +# Copyright 2016 MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test maxStalenessSeconds support.""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +import warnings +from pathlib import Path + +from pymongo import AsyncMongoClient +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest +from test.asynchronous.utils_selection_tests import create_selection_tests + +from pymongo.errors import ConfigurationError +from pymongo.server_selectors import writable_server_selector + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness") + + +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore + pass + + +class TestMaxStaleness(AsyncPyMongoTestCase): + async def test_max_staleness(self): + client = self.simple_client() + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary") + self.assertEqual(-1, client.read_preference.max_staleness) + + # These tests are specified in max-staleness-tests.rst. + with self.assertRaises(ConfigurationError): + # Default read pref "primary" can't be used with max staleness. + self.simple_client("mongodb://a/?maxStalenessSeconds=120") + + with self.assertRaises(ConfigurationError): + # Read pref "primary" can't be used with max staleness. + self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") + + client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client( + "mongodb://host/?readPreference=secondary&maxStalenessSeconds=120" + ) + self.assertEqual(120, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") + self.assertEqual(1, client.read_preference.max_staleness) + + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") + self.assertEqual(-1, client.read_preference.max_staleness) + + client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest") + self.assertEqual(-1, client.read_preference.max_staleness) + + with self.assertRaises(TypeError): + # Prohibit None. 
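+ # maxStalenessSeconds=None is a TypeError rather than meaning "no limit"; -1 is the explicit "no maximum" value.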
+            self.simple_client(maxStalenessSeconds=None, readPreference="nearest")
+
+    async def test_max_staleness_float(self):
+        with self.assertRaises(TypeError) as ctx:
+            await self.async_rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
+
+        self.assertIn("must be an integer", str(ctx.exception))
+
+        with warnings.catch_warnings(record=True) as ctx:
+            warnings.simplefilter("always")
+            client = self.simple_client(
+                "mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest"
+            )
+
+            # Option was ignored.
+            self.assertEqual(-1, client.read_preference.max_staleness)
+            self.assertIn("must be an integer", str(ctx[0]))
+
+    async def test_max_staleness_zero(self):
+        # Zero is too small.
+        with self.assertRaises(ValueError) as ctx:
+            await self.async_rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
+
+        self.assertIn("must be a positive integer", str(ctx.exception))
+
+        with warnings.catch_warnings(record=True) as ctx:
+            warnings.simplefilter("always")
+            client = self.simple_client(
+                "mongodb://host/?maxStalenessSeconds=0&readPreference=nearest"
+            )
+
+            # Option was ignored.
+            self.assertEqual(-1, client.read_preference.max_staleness)
+            self.assertIn("must be a positive integer", str(ctx[0]))
+
+    @async_client_context.require_replica_set
+    async def test_last_write_date(self):
+        # From max-staleness-tests.rst, "Parse lastWriteDate".
+        client = await self.async_rs_or_single_client(heartbeatFrequencyMS=500)
+        await client.pymongo_test.test.insert_one({})
+        # Wait for the server description to be updated.
+        await asyncio.sleep(1)
+        server = await client._topology.select_server(writable_server_selector, _Op.TEST)
+        first = server.description.last_write_date
+        self.assertTrue(first)
+        # The first last_write_date may correspond to an internal server write;
+        # sleep so that the next write does not occur within the same second.
+        await asyncio.sleep(1)
+        await client.pymongo_test.test.insert_one({})
+        # Wait for the server description to be updated.
+        await asyncio.sleep(1)
+        server = await client._topology.select_server(writable_server_selector, _Op.TEST)
+        second = server.description.last_write_date
+        assert first is not None
+
+        assert second is not None
+        self.assertGreater(second, first)
+        self.assertLess(second, first + 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py
new file mode 100644
index 0000000000..97170aa9e0
--- /dev/null
+++ b/test/asynchronous/test_mongos_load_balancing.py
@@ -0,0 +1,199 @@
+# Copyright 2015-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Test AsyncMongoClient's mongos load balancing using a mock.""" +from __future__ import annotations + +import asyncio +import sys +import threading +from test.asynchronous.helpers import ConcurrentRunner + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest +from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.utils_shared import async_wait_until + +from pymongo.errors import AutoReconnect, InvalidOperation +from pymongo.server_selectors import writable_server_selector +from pymongo.topology_description import TOPOLOGY_TYPE + +_IS_SYNC = False + + +class SimpleOp(ConcurrentRunner): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + + async def run(self): + await self.client.db.command("ping") + self.passed = True # No exception raised. + + +async def do_simple_op(client, ntasks): + tasks = [SimpleOp(client) for _ in range(ntasks)] + for t in tasks: + await t.start() + + for t in tasks: + await t.join() + + for t in tasks: + assert t.passed + + +async def writable_addresses(topology): + return { + server.description.address + for server in await topology.select_servers(writable_server_selector, _Op.TEST) + } + + +class TestMongosLoadBalancing(AsyncMockClientTest): + @async_client_context.require_connection + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + + def mock_client(self, **kwargs): + mock_client = AsyncMockClient( + standalones=[], + members=[], + mongoses=["a:1", "b:2", "c:3"], + host="a:1,b:2,c:3", + connect=False, + **kwargs, + ) + self.addAsyncCleanup(mock_client.aclose) + + # Latencies in seconds. + mock_client.mock_rtts["a:1"] = 0.020 + mock_client.mock_rtts["b:2"] = 0.025 + mock_client.mock_rtts["c:3"] = 0.045 + return mock_client + + async def test_lazy_connect(self): + # While connected() ensures we can trigger connection from the main + # thread and wait for the monitors, this test triggers connection from + # several threads at once to check for data races. + nthreads = 10 + client = self.mock_client() + self.assertEqual(0, len(client.nodes)) + + # Trigger initial connection. + await do_simple_op(client, nthreads) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + async def test_failover(self): + ntasks = 10 + client = await connected(self.mock_client(localThresholdMS=0.001)) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + # Our chosen mongos goes down. + client.kill_host("a:1") + + # Trigger failover to higher-latency nodes. AutoReconnect should be + # raised at most once in each thread. + passed = [] + + async def f(): + try: + await client.db.command("ping") + except AutoReconnect: + # Second attempt succeeds. + await client.db.command("ping") + + passed.append(True) + + tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)] + for t in tasks: + await t.start() + + for t in tasks: + await t.join() + + self.assertEqual(ntasks, len(passed)) + + # Down host removed from list. + self.assertEqual(2, len(client.nodes)) + + async def test_local_threshold(self): + client = await connected(self.mock_client(localThresholdMS=30)) + self.assertEqual(30, client.options.local_threshold_ms) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + topology = client._topology + + # All are within a 30-ms latency window, see self.mock_client(). 
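+        # (Sanity-checking the selection math with the mocked RTTs above: the
+        # fastest mongos is a:1 at 20 ms, so a 30-ms threshold gives a window
+        # of [20 ms, 50 ms], and a:1 (20 ms), b:2 (25 ms), and c:3 (45 ms)
+        # all qualify.)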
+ self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, await writable_addresses(topology)) + + # No error + await client.admin.command("ping") + + client = await connected(self.mock_client(localThresholdMS=0)) + self.assertEqual(0, client.options.local_threshold_ms) + # No error + await client.db.command("ping") + # Our chosen mongos goes down. + client.kill_host("{}:{}".format(*next(iter(client.nodes)))) + try: + await client.db.command("ping") + except: + pass + + # We eventually connect to a new mongos. + async def connect_to_new_mongos(): + try: + return await client.db.command("ping") + except AutoReconnect: + pass + + await async_wait_until(connect_to_new_mongos, "connect to a new mongos") + + async def test_load_balancing(self): + # Although the server selection JSON tests already prove that + # select_servers works for sharded topologies, here we do an end-to-end + # test of discovering servers' round trip times and configuring + # localThresholdMS. + client = await connected(self.mock_client()) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + # Prohibited for topology type Sharded. + with self.assertRaises(InvalidOperation): + await client.address + + topology = client._topology + self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type) + + # a and b are within the 15-ms latency window, see self.mock_client(). + self.assertEqual({("a", 1), ("b", 2)}, await writable_addresses(topology)) + + client.mock_rtts["a:1"] = 0.045 + + # Discover only b is within latency window. + async def predicate(): + return {("b", 2)} == await writable_addresses(topology) + + await async_wait_until( + predicate, + 'discover server "a" is too far', + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_monitor.py b/test/asynchronous/test_monitor.py new file mode 100644 index 0000000000..195f6f9fac --- /dev/null +++ b/test/asynchronous/test_monitor.py @@ -0,0 +1,121 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the monitor module.""" +from __future__ import annotations + +import asyncio +import gc +import subprocess +import sys +import warnings +from functools import partial + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest +from test.asynchronous.utils import ( + async_wait_until, +) +from test.utils_shared import ServerAndTopologyEventListener + +from pymongo.periodic_executor import _EXECUTORS + +_IS_SYNC = False + + +def unregistered(ref): + gc.collect() + return ref not in _EXECUTORS + + +def get_executors(client): + executors = [] + for server in client._topology._servers.values(): + executors.append(server._monitor._executor) + executors.append(server._monitor._rtt_monitor._executor) + executors.append(client._kill_cursors_executor) + executors.append(client._topology._Topology__events_executor) + return [e for e in executors if e is not None] + + +class TestMonitor(AsyncIntegrationTest): + async def create_client(self): + listener = ServerAndTopologyEventListener() + client = await self.unmanaged_async_single_client(event_listeners=[listener]) + await connected(client) + return client + + async def test_cleanup_executors_on_client_del(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + client = await self.create_client() + executors = get_executors(client) + self.assertEqual(len(executors), 4) + + # Each executor stores a weakref to itself in _EXECUTORS. + executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + + del executors + del client + + for ref, name in executor_refs: + await async_wait_until( + partial(unregistered, ref), f"unregister executor: {name}", timeout=5 + ) + + def resource_warning_caught(): + gc.collect() + for warning in w: + if ( + issubclass(warning.category, ResourceWarning) + and "Call AsyncMongoClient.close() to safely shut down your client and free up resources." 
+ in str(warning.message) + ): + return True + return False + + await async_wait_until(resource_warning_caught, "catch resource warning") + + async def test_cleanup_executors_on_client_close(self): + client = await self.create_client() + executors = get_executors(client) + self.assertEqual(len(executors), 4) + + await client.close() + + for executor in executors: + await async_wait_until( + lambda: executor._stopped, f"closed executor: {executor._name}", timeout=5 + ) + + @async_client_context.require_sync + def test_no_thread_start_runtime_err_on_shutdown(self): + """Test we silence noisy runtime errors fired when the AsyncMongoClient spawns a new thread + on process shutdown.""" + command = [ + sys.executable, + "-c", + "from pymongo import AsyncMongoClient; c = AsyncMongoClient()", + ] + completed_process: subprocess.CompletedProcess = subprocess.run( + command, capture_output=True + ) + + self.assertFalse(completed_process.stderr) + self.assertFalse(completed_process.stdout) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index eaad60beac..a7d56a8cf7 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_on_demand_csfle.py b/test/asynchronous/test_on_demand_csfle.py new file mode 100644 index 0000000000..55394ddeb8 --- /dev/null +++ b/test/asynchronous/test_on_demand_csfle.py @@ -0,0 +1,115 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test client side encryption with on demand credentials.""" +from __future__ import annotations + +import os +import sys +import unittest + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context + +from bson.codec_options import CodecOptions +from pymongo.asynchronous.encryption import ( + _HAVE_PYMONGOCRYPT, + AsyncClientEncryption, + EncryptionError, +) + +_IS_SYNC = False + +pytestmark = pytest.mark.kms + + +class TestonDemandGCPCredentials(AsyncIntegrationTest): + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + @async_client_context.require_version_min(4, 2, -1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.master_key = { + "projectId": "devprod-drivers", + "location": "global", + "keyRing": "key-ring-csfle", + "keyName": "key-name-csfle", + } + + @unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto") + async def test_01_failure(self): + if os.environ["SUCCESS"].lower() == "true": + self.skipTest("Expecting success") + self.client_encryption = AsyncClientEncryption( + kms_providers={"gcp": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + with self.assertRaises(EncryptionError): + await self.client_encryption.create_data_key("gcp", self.master_key) + + @unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto") + async def test_02_success(self): + if os.environ["SUCCESS"].lower() == "false": + self.skipTest("Expecting failure") + self.client_encryption = AsyncClientEncryption( + kms_providers={"gcp": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + await self.client_encryption.create_data_key("gcp", self.master_key) + + +class TestonDemandAzureCredentials(AsyncIntegrationTest): + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + @async_client_context.require_version_min(4, 2, -1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.master_key = { + "keyVaultEndpoint": os.environ["KEY_VAULT_ENDPOINT"], + "keyName": os.environ["KEY_NAME"], + } + + @unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto") + async def test_01_failure(self): + if os.environ["SUCCESS"].lower() == "true": + self.skipTest("Expecting success") + self.client_encryption = AsyncClientEncryption( + kms_providers={"azure": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + with self.assertRaises(EncryptionError): + await self.client_encryption.create_data_key("azure", self.master_key) + + @unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto") + async def test_02_success(self): + if os.environ["SUCCESS"].lower() == "false": + self.skipTest("Expecting failure") + self.client_encryption = AsyncClientEncryption( + kms_providers={"azure": {}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=async_client_context.client, + codec_options=CodecOptions(), + ) + await self.client_encryption.create_data_key("azure", self.master_key) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py new file mode 100644 index 0000000000..64c5738dba --- /dev/null +++ b/test/asynchronous/test_pooling.py @@ -0,0 +1,613 @@ +# Copyright 
2009-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test built-in connection pooling with threads."""
+from __future__ import annotations
+
+import asyncio
+import gc
+import random
+import socket
+import sys
+import time
+from test.asynchronous.utils import async_get_pool, async_joinall
+
+from bson.codec_options import DEFAULT_CODEC_OPTIONS
+from bson.son import SON
+from pymongo import AsyncMongoClient, message, timeout
+from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError
+from pymongo.hello import HelloCompat
+from pymongo.lock import _async_create_lock
+
+sys.path[0:0] = [""]
+
+from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
+from test.asynchronous.helpers import ConcurrentRunner
+from test.utils_shared import delay
+
+from pymongo.asynchronous.pool import Pool, PoolOptions
+from pymongo.socket_checker import SocketChecker
+
+_IS_SYNC = False
+
+
+N = 10
+DB = "pymongo-pooling-tests"
+
+
+async def gc_collect_until_done(tasks, timeout=60):
+    start = time.time()
+    running = list(tasks)
+    while running:
+        assert (time.time() - start) < timeout, "Tasks timed out"
+        for t in running:
+            await t.join(0.1)
+            if not t.is_alive():
+                running.remove(t)
+        gc.collect()
+
+
+class MongoTask(ConcurrentRunner):
+    """A thread/Task that uses an AsyncMongoClient."""
+
+    def __init__(self, client):
+        super().__init__()
+        self.daemon = True  # Don't hang whole test if task hangs.
+        self.client = client
+        self.db = self.client[DB]
+        self.passed = False
+
+    async def run(self):
+        await self.run_mongo_thread()
+        self.passed = True
+
+    async def run_mongo_thread(self):
+        raise NotImplementedError
+
+
+class InsertOneAndFind(MongoTask):
+    async def run_mongo_thread(self):
+        for _ in range(N):
+            rand = random.randint(0, N)
+            _id = (await self.db.sf.insert_one({"x": rand})).inserted_id
+            assert rand == (await self.db.sf.find_one(_id))["x"]
+
+
+class Unique(MongoTask):
+    async def run_mongo_thread(self):
+        for _ in range(N):
+            await self.db.unique.insert_one({})  # no error
+
+
+class NonUnique(MongoTask):
+    async def run_mongo_thread(self):
+        for _ in range(N):
+            try:
+                await self.db.unique.insert_one({"_id": "jesse"})
+            except DuplicateKeyError:
+                pass
+            else:
+                raise AssertionError("Should have raised DuplicateKeyError")
+
+
+class SocketGetter(MongoTask):
+    """Utility for TestPooling.
+
+    Checks out a socket and holds it forever. Used in
+    test_no_wait_queue_timeout.
+    """
+
+    def __init__(self, client, pool):
+        super().__init__(client)
+        self.state = "init"
+        self.pool = pool
+        self.sock = None
+
+    async def run_mongo_thread(self):
+        self.state = "get_socket"
+
+        # Call 'pin_cursor' so we can hold the socket.
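+        # (Exiting the checkout() context manager would normally return the
+        # connection to the pool; pinning keeps it checked out until
+        # release_conn() unpins it.)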
+ async with self.pool.checkout() as sock: + sock.pin_cursor() + self.sock = sock + + self.state = "connection" + + async def release_conn(self): + if self.sock: + await self.sock.unpin() + self.sock = None + return True + return False + + +async def run_cases(client, cases): + tasks = [] + n_runs = 5 + + for case in cases: + for _i in range(n_runs): + t = case(client) + await t.start() + tasks.append(t) + + for t in tasks: + await t.join() + + for t in tasks: + assert t.passed, "%s.run() threw an exception" % repr(t) + + +class _TestPoolingBase(AsyncIntegrationTest): + """Base class for all connection-pool tests.""" + + @async_client_context.require_connection + async def asyncSetUp(self): + await super().asyncSetUp() + self.c = await self.async_rs_or_single_client() + db = self.c[DB] + await db.unique.drop() + await db.test.drop() + await db.unique.insert_one({"_id": "jesse"}) + await db.test.insert_many([{} for _ in range(10)]) + + async def create_pool(self, pair=None, *args, **kwargs): + if pair is None: + pair = (await async_client_context.host, await async_client_context.port) + # Start the pool with the correct ssl options. + pool_options = async_client_context.client._topology_settings.pool_options + kwargs["ssl_context"] = pool_options._ssl_context + kwargs["tls_allow_invalid_hostnames"] = pool_options.tls_allow_invalid_hostnames + kwargs["server_api"] = pool_options.server_api + pool = Pool(pair, PoolOptions(*args, **kwargs)) + await pool.ready() + return pool + + +class TestPooling(_TestPoolingBase): + async def test_max_pool_size_validation(self): + host, port = await async_client_context.host, await async_client_context.port + self.assertRaises(ValueError, AsyncMongoClient, host=host, port=port, maxPoolSize=-1) + + self.assertRaises(ValueError, AsyncMongoClient, host=host, port=port, maxPoolSize="foo") + + c = AsyncMongoClient(host=host, port=port, maxPoolSize=100, connect=False) + self.assertEqual(c.options.pool_options.max_pool_size, 100) + + async def test_no_disconnect(self): + await run_cases(self.c, [NonUnique, Unique, InsertOneAndFind]) + + async def test_pool_reuses_open_socket(self): + # Test Pool's _check_closed() method doesn't close a healthy socket. + cx_pool = await self.create_pool(max_pool_size=10) + cx_pool._check_interval_seconds = 0 # Always check. + async with cx_pool.checkout() as conn: + pass + + async with cx_pool.checkout() as new_connection: + self.assertEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + async def test_get_socket_and_exception(self): + # get_socket() returns socket after a non-network error. + cx_pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=1) + with self.assertRaises(ZeroDivisionError): + async with cx_pool.checkout() as conn: + 1 / 0 + + # Socket was returned, not closed. + async with cx_pool.checkout() as new_connection: + self.assertEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + async def test_pool_removes_closed_socket(self): + # Test that Pool removes explicitly closed socket. + cx_pool = await self.create_pool() + + async with cx_pool.checkout() as conn: + # Use Connection's API to close the socket. + await conn.close_conn(None) + + self.assertEqual(0, len(cx_pool.conns)) + + async def test_pool_removes_dead_socket(self): + # Test that Pool removes dead socket and the socket doesn't return + # itself PYTHON-344 + cx_pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=1) + cx_pool._check_interval_seconds = 0 # Always check. 
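+        # (A zero check interval makes the pool health-check a connection on
+        # every checkout, so the socket killed below is noticed immediately.)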
+ + async with cx_pool.checkout() as conn: + # Simulate a closed socket without telling the Connection it's + # closed. + await conn.conn.close() + self.assertTrue(conn.conn_closed()) + + async with cx_pool.checkout() as new_connection: + self.assertEqual(0, len(cx_pool.conns)) + self.assertNotEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + # Semaphore was released. + async with cx_pool.checkout(): + pass + + async def test_socket_closed(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((await async_client_context.host, await async_client_context.port)) + socket_checker = SocketChecker() + self.assertFalse(socket_checker.socket_closed(s)) + s.close() + self.assertTrue(socket_checker.socket_closed(s)) + + async def test_socket_checker(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((await async_client_context.host, await async_client_context.port)) + socket_checker = SocketChecker() + # Socket has nothing to read. + self.assertFalse(socket_checker.select(s, read=True)) + self.assertFalse(socket_checker.select(s, read=True, timeout=0)) + self.assertFalse(socket_checker.select(s, read=True, timeout=0.05)) + # Socket is writable. + self.assertTrue(socket_checker.select(s, write=True, timeout=None)) + self.assertTrue(socket_checker.select(s, write=True)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) + # Make the socket readable + _, msg, _ = message._query( + 0, "admin.$cmd", 0, -1, SON([("ping", 1)]), None, DEFAULT_CODEC_OPTIONS + ) + s.sendall(msg) + # Block until the socket is readable. + self.assertTrue(socket_checker.select(s, read=True, timeout=None)) + self.assertTrue(socket_checker.select(s, read=True)) + self.assertTrue(socket_checker.select(s, read=True, timeout=0)) + self.assertTrue(socket_checker.select(s, read=True, timeout=0.05)) + # Socket is still writable. + self.assertTrue(socket_checker.select(s, write=True, timeout=None)) + self.assertTrue(socket_checker.select(s, write=True)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) + s.close() + self.assertTrue(socket_checker.socket_closed(s)) + + async def test_return_socket_after_reset(self): + pool = await self.create_pool() + async with pool.checkout() as sock: + self.assertEqual(pool.active_sockets, 1) + self.assertEqual(pool.operation_count, 1) + await pool.reset() + + self.assertTrue(sock.closed) + self.assertEqual(0, len(pool.conns)) + self.assertEqual(pool.active_sockets, 0) + self.assertEqual(pool.operation_count, 0) + + async def test_pool_check(self): + # Test that Pool recovers from two connection failures in a row. + # This exercises code at the end of Pool._check(). + cx_pool = await self.create_pool(max_pool_size=1, connect_timeout=1, wait_queue_timeout=1) + cx_pool._check_interval_seconds = 0 # Always check. + self.addAsyncCleanup(cx_pool.close) + + async with cx_pool.checkout() as conn: + # Simulate a closed socket without telling the Connection it's + # closed. + await conn.conn.close() + + # Swap pool's address with a bad one. + address, cx_pool.address = cx_pool.address, ("foo.com", 1234) + with self.assertRaises(AutoReconnect): + async with cx_pool.checkout(): + pass + + # Back to normal, semaphore was correctly released. 
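+        # (Had the semaphore leaked, this checkout would find no free slot in
+        # the max_pool_size=1 pool and would time out after wait_queue_timeout
+        # instead of succeeding.)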
+ cx_pool.address = address + async with cx_pool.checkout(): + pass + + async def test_wait_queue_timeout(self): + wait_queue_timeout = 2 # Seconds + pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout) + self.addAsyncCleanup(pool.close) + + async with pool.checkout(): + start = time.time() + with self.assertRaises(ConnectionFailure): + async with pool.checkout(): + pass + + duration = time.time() - start + self.assertTrue( + abs(wait_queue_timeout - duration) < 1, + f"Waited {duration:.2f} seconds for a socket, expected {wait_queue_timeout:f}", + ) + + async def test_no_wait_queue_timeout(self): + # Verify get_socket() with no wait_queue_timeout blocks forever. + pool = await self.create_pool(max_pool_size=1) + self.addAsyncCleanup(pool.close) + + # Reach max_size. + async with pool.checkout() as s1: + t = SocketGetter(self.c, pool) + await t.start() + while t.state != "get_socket": + await asyncio.sleep(0.1) + + await asyncio.sleep(1) + self.assertEqual(t.state, "get_socket") + + while t.state != "connection": + await asyncio.sleep(0.1) + + self.assertEqual(t.state, "connection") + self.assertEqual(t.sock, s1) + # Cleanup + await t.release_conn() + await t.join() + await pool.close() + + async def test_checkout_more_than_max_pool_size(self): + pool = await self.create_pool(max_pool_size=2) + + socks = [] + for _ in range(2): + # Call 'pin_cursor' so we can hold the socket. + async with pool.checkout() as sock: + sock.pin_cursor() + socks.append(sock) + + tasks = [] + for _ in range(10): + t = SocketGetter(self.c, pool) + await t.start() + tasks.append(t) + await asyncio.sleep(1) + for t in tasks: + self.assertEqual(t.state, "get_socket") + # Cleanup + for socket_info in socks: + await socket_info.unpin() + while tasks: + to_remove = [] + for t in tasks: + if await t.release_conn(): + to_remove.append(t) + await t.join() + for t in to_remove: + tasks.remove(t) + await asyncio.sleep(0.05) + await pool.close() + + async def test_maxConnecting(self): + client = await self.async_rs_or_single_client() + await self.client.test.test.insert_one({}) + self.addAsyncCleanup(self.client.test.test.delete_many, {}) + pool = await async_get_pool(client) + docs = [] + + # Run 50 short running operations + async def find_one(): + docs.append(await client.test.test.find_one({})) + + tasks = [ConcurrentRunner(target=find_one) for _ in range(50)] + for task in tasks: + await task.start() + for task in tasks: + await task.join(10) + + self.assertEqual(len(docs), 50) + self.assertLessEqual(len(pool.conns), 50) + # TLS and auth make connection establishment more expensive than + # the query which leads to more threads hitting maxConnecting. + # The end result is fewer total connections and better latency. + if async_client_context.tls and async_client_context.auth_enabled: + self.assertLessEqual(len(pool.conns), 30) + else: + self.assertLessEqual(len(pool.conns), 50) + # MongoDB 4.4.1 with auth + ssl: + # maxConnecting = 2: 6 connections in ~0.231+ seconds + # maxConnecting = unbounded: 50 connections in ~0.642+ seconds + # + # MongoDB 4.4.1 with no-auth no-ssl Python 3.8: + # maxConnecting = 2: 15-22 connections in ~0.108+ seconds + # maxConnecting = unbounded: 30+ connections in ~0.140+ seconds + print(len(pool.conns)) + + @async_client_context.require_failCommand_appName + async def test_csot_timeout_message(self): + client = await self.async_rs_or_single_client(appName="connectionTimeoutApp") + # Mock an operation failing due to pymongo.timeout(). 
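+        # (The fail point below blocks "find" for 1000 ms while the operation
+        # runs under timeout(0.5), a 500 ms client-side deadline, so the
+        # command is guaranteed to exceed its budget and raise.)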
+ mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": ["find"], + "appName": "connectionTimeoutApp", + }, + } + + await client.db.t.insert_one({"x": 1}) + + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + with timeout(0.5): + await client.db.t.find_one({"$where": delay(2)}) + + self.assertIn("(configured timeouts: timeoutMS: 500.0ms", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_socket_timeout_message(self): + client = await self.async_rs_or_single_client( + socketTimeoutMS=500, appName="connectionTimeoutApp" + ) + # Mock an operation failing due to socketTimeoutMS. + mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": ["find"], + "appName": "connectionTimeoutApp", + }, + } + + await client.db.t.insert_one({"x": 1}) + + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + await client.db.t.find_one({"$where": delay(2)}) + + self.assertIn( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)", + str(error.exception), + ) + + @async_client_context.require_failCommand_appName + async def test_connection_timeout_message(self): + # Mock a connection creation failing due to timeout. + mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "appName": "connectionTimeoutApp", + }, + } + + client = await self.async_rs_or_single_client( + connectTimeoutMS=500, + socketTimeoutMS=500, + appName="connectionTimeoutApp", + heartbeatFrequencyMS=1000000, + ) + await client.admin.command("ping") + pool = await async_get_pool(client) + await pool.reset_without_pause() + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + await client.admin.command("ping") + + self.assertIn( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 500.0ms)", + str(error.exception), + ) + + +class TestPoolMaxSize(_TestPoolingBase): + async def test_max_pool_size(self): + max_pool_size = 4 + c = await self.async_rs_or_single_client(maxPoolSize=max_pool_size) + collection = c[DB].test + + # Need one document. + await collection.drop() + await collection.insert_one({}) + + # ntasks had better be much larger than max_pool_size to ensure that + # max_pool_size connections are actually required at some point in this + # test's execution. + cx_pool = await async_get_pool(c) + ntasks = 10 + tasks = [] + lock = _async_create_lock() + self.n_passed = 0 + + async def f(): + for _ in range(5): + await collection.find_one({"$where": delay(0.1)}) + assert len(cx_pool.conns) <= max_pool_size + + async with lock: + self.n_passed += 1 + + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) + await t.start() + + await async_joinall(tasks) + self.assertEqual(ntasks, self.n_passed) + self.assertTrue(len(cx_pool.conns) > 1) + self.assertEqual(0, cx_pool.requests) + + async def test_max_pool_size_none(self): + c = await self.async_rs_or_single_client(maxPoolSize=None) + collection = c[DB].test + + # Need one document. 
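+        # (maxPoolSize=None removes the cap entirely; the assertions below
+        # verify that concurrent use still works and that the pool reports an
+        # infinite max size.)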
+ await collection.drop() + await collection.insert_one({}) + + cx_pool = await async_get_pool(c) + ntasks = 10 + tasks = [] + lock = _async_create_lock() + self.n_passed = 0 + + async def f(): + for _ in range(5): + await collection.find_one({"$where": delay(0.1)}) + + async with lock: + self.n_passed += 1 + + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) + await t.start() + + await async_joinall(tasks) + self.assertEqual(ntasks, self.n_passed) + self.assertTrue(len(cx_pool.conns) > 1) + self.assertEqual(cx_pool.max_pool_size, float("inf")) + + async def test_max_pool_size_zero(self): + c = await self.async_rs_or_single_client(maxPoolSize=0) + pool = await async_get_pool(c) + self.assertEqual(pool.max_pool_size, float("inf")) + + async def test_max_pool_size_with_connection_failure(self): + # The pool acquires its semaphore before attempting to connect; ensure + # it releases the semaphore on connection failure. + test_pool = Pool( + ("somedomainthatdoesntexist.org", 27017), + PoolOptions(max_pool_size=1, connect_timeout=1, socket_timeout=1, wait_queue_timeout=1), + ) + await test_pool.ready() + + # First call to get_socket fails; if pool doesn't release its semaphore + # then the second call raises "ConnectionFailure: Timed out waiting for + # socket from pool" instead of AutoReconnect. + for _i in range(2): + with self.assertRaises(AutoReconnect) as context: + async with test_pool.checkout(): + pass + + # Testing for AutoReconnect instead of ConnectionFailure, above, + # is sufficient right *now* to catch a semaphore leak. But that + # seems error-prone, so check the message too. + self.assertNotIn("waiting for socket from pool", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_read_concern.py b/test/asynchronous/test_read_concern.py new file mode 100644 index 0000000000..8659bf80b2 --- /dev/null +++ b/test/asynchronous/test_read_concern.py @@ -0,0 +1,122 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the read_concern module.""" +from __future__ import annotations + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context +from test.utils_shared import OvertCommandListener + +from bson.son import SON +from pymongo.errors import OperationFailure +from pymongo.read_concern import ReadConcern + +_IS_SYNC = False + + +class TestReadConcern(AsyncIntegrationTest): + listener: OvertCommandListener + + @async_client_context.require_connection + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + await async_client_context.client.pymongo_test.create_collection("coll") + + async def asyncTearDown(self): + await async_client_context.client.pymongo_test.drop_collection("coll") + + def test_read_concern(self): + rc = ReadConcern() + self.assertIsNone(rc.level) + self.assertTrue(rc.ok_for_legacy) + + rc = ReadConcern("majority") + self.assertEqual("majority", rc.level) + self.assertFalse(rc.ok_for_legacy) + + rc = ReadConcern("local") + self.assertEqual("local", rc.level) + self.assertTrue(rc.ok_for_legacy) + + self.assertRaises(TypeError, ReadConcern, 42) + + async def test_read_concern_uri(self): + uri = f"mongodb://{await async_client_context.pair}/?readConcernLevel=majority" + client = await self.async_rs_or_single_client(uri, connect=False) + self.assertEqual(ReadConcern("majority"), client.read_concern) + + async def test_invalid_read_concern(self): + coll = self.db.get_collection("coll", read_concern=ReadConcern("unknown")) + # We rely on the server to validate read concern. + with self.assertRaises(OperationFailure): + await coll.find_one() + + async def test_find_command(self): + # readConcern not sent in command if not specified. + coll = self.db.coll + await coll.find({"field": "value"}).to_list() + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + self.listener.reset() + + # Explicitly set readConcern to 'local'. + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await coll.find({"field": "value"}).to_list() + self.assertEqualCommand( + SON( + [ + ("find", "coll"), + ("filter", {"field": "value"}), + ("readConcern", {"level": "local"}), + ] + ), + self.listener.started_events[0].command, + ) + + async def test_command_cursor(self): + # readConcern not sent in command if not specified. + coll = self.db.coll + await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list() + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + self.listener.reset() + + # Explicitly set readConcern to 'local'. + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list() + self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"]) + + async def test_aggregate_out(self): + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + await ( + await coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}]) + ).to_list() + + # Aggregate with $out supports readConcern MongoDB 4.2 onwards. 
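+        # (The check below uses (4, 1) because the 4.1.x development series
+        # led into the 4.2 release, so those servers already accept it.)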
+ if async_client_context.version >= (4, 1): + self.assertIn("readConcern", self.listener.started_events[0].command) + else: + self.assertNotIn("readConcern", self.listener.started_events[0].command) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py new file mode 100644 index 0000000000..72dd809db0 --- /dev/null +++ b/test/asynchronous/test_read_preferences.py @@ -0,0 +1,743 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the replica_set_connection module.""" +from __future__ import annotations + +import contextlib +import copy +import pickle +import random +import sys +from typing import Any + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + SkipTest, + async_client_context, + connected, + unittest, +) +from test.utils_shared import ( + OvertCommandListener, + _ignore_deprecations, + async_wait_until, + one, +) +from test.version import Version + +from bson.son import SON +from pymongo.asynchronous.helpers import anext +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.message import _maybe_add_read_preference +from pymongo.read_preferences import ( + MovingAverage, + Nearest, + Primary, + PrimaryPreferred, + ReadPreference, + Secondary, + SecondaryPreferred, +) +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection, readable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestSelections(AsyncIntegrationTest): + @async_client_context.require_connection + async def test_bool(self): + client = await self.async_single_client() + + async def predicate(): + return await client.address + + await async_wait_until(predicate, "discover primary") + selection = Selection.from_topology_description(client._topology.description) + + self.assertTrue(selection) + self.assertFalse(selection.with_server_descriptions([])) + + +class TestReadPreferenceObjects(unittest.TestCase): + prefs = [ + Primary(), + PrimaryPreferred(), + Secondary(), + Nearest(tag_sets=[{"a": 1}, {"b": 2}]), + SecondaryPreferred(max_staleness=30), + ] + + def test_pickle(self): + for pref in self.prefs: + self.assertEqual(pref, pickle.loads(pickle.dumps(pref))) + + def test_copy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.copy(pref)) + + def test_deepcopy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.deepcopy(pref)) + + +class TestReadPreferencesBase(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + # Insert some data so we can use cursors in read_from_which_host + await self.client.pymongo_test.test.drop() + await self.client.get_database( + 
"pymongo_test", write_concern=WriteConcern(w=async_client_context.w) + ).test.insert_many([{"_id": i} for i in range(10)]) + + self.addAsyncCleanup(self.client.pymongo_test.test.drop) + + async def read_from_which_host(self, client): + """Do a find() on the client and return which host was used""" + cursor = client.pymongo_test.test.find() + await anext(cursor) + return cursor.address + + async def read_from_which_kind(self, client): + """Do a find() on the client and return 'primary' or 'secondary' + depending on which the client used. + """ + address = await self.read_from_which_host(client) + if address == await client.primary: + return "primary" + elif address in await client.secondaries: + return "secondary" + else: + self.fail( + f"Cursor used address {address}, expected either primary " + f"{client.primary} or secondaries {client.secondaries}" + ) + + async def assertReadsFrom(self, expected, **kwargs): + c = await self.async_rs_client(**kwargs) + + async def predicate(): + return len(c.nodes - await c.arbiters) == async_client_context.w + + await async_wait_until(predicate, "discovered all nodes") + + used = await self.read_from_which_kind(c) + self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") + + +class TestSingleSecondaryOk(TestReadPreferencesBase): + async def test_reads_from_secondary(self): + host, port = next(iter(await self.client.secondaries)) + # Direct connection to a secondary. + client = await self.async_single_client(host, port) + self.assertFalse(await client.is_primary) + + # Regardless of read preference, we should be able to do + # "reads" with a direct connection to a secondary. + # See server-selection.rst#topology-type-single. + self.assertEqual(client.read_preference, ReadPreference.PRIMARY) + + db = client.pymongo_test + coll = db.test + + # Test find and find_one. + self.assertIsNotNone(await coll.find_one()) + self.assertEqual(10, len(await coll.find().to_list())) + + # Test some database helpers. + self.assertIsNotNone(await db.list_collection_names()) + self.assertIsNotNone(await db.validate_collection("test")) + self.assertIsNotNone(await db.command("ping")) + + # Test some collection helpers. 
+ self.assertEqual(10, await coll.count_documents({})) + self.assertEqual(10, len(await coll.distinct("_id"))) + self.assertIsNotNone(await coll.aggregate([])) + self.assertIsNotNone(await coll.index_information()) + + +class TestReadPreferences(TestReadPreferencesBase): + async def test_mode_validation(self): + for mode in ( + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED, + ReadPreference.NEAREST, + ): + self.assertEqual( + mode, (await self.async_rs_client(read_preference=mode)).read_preference + ) + + with self.assertRaises(TypeError): + await self.async_rs_client(read_preference="foo") + + async def test_tag_sets_validation(self): + S = Secondary(tag_sets=[{}]) + self.assertEqual( + [{}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}]) + self.assertEqual( + [{"k": "v"}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}, {}]) + self.assertEqual( + [{"k": "v"}, {}], + (await self.async_rs_client(read_preference=S)).read_preference.tag_sets, + ) + + self.assertRaises(ValueError, Secondary, tag_sets=[]) + + # One dict not ok, must be a list of dicts + self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"}) + + self.assertRaises(TypeError, Secondary, tag_sets="foo") + + self.assertRaises(TypeError, Secondary, tag_sets=["foo"]) + + async def test_threshold_validation(self): + self.assertEqual( + 17, + ( + await self.async_rs_client(localThresholdMS=17, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 42, + ( + await self.async_rs_client(localThresholdMS=42, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 666, + ( + await self.async_rs_client(localThresholdMS=666, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 0, + ( + await self.async_rs_client(localThresholdMS=0, connect=False) + ).options.local_threshold_ms, + ) + + with self.assertRaises(ValueError): + await self.async_rs_client(localthresholdms=-1) + + async def test_zero_latency(self): + ping_times: set = set() + # Generate unique ping times. 
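+        # (Distinct RTTs make "nearest" with localThresholdMS=0 deterministic:
+        # a zero-width latency window admits only the single fastest member,
+        # so every read below should land on the same host.)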
+ while len(ping_times) < len(self.client.nodes): + ping_times.add(random.random()) + for ping_time, host in zip(ping_times, self.client.nodes): + ServerDescription._host_to_round_trip_time[host] = ping_time + try: + client = await connected( + await self.async_rs_client(readPreference="nearest", localThresholdMS=0) + ) + await async_wait_until( + lambda: client.nodes == self.client.nodes, "discovered all nodes" + ) + host = await self.read_from_which_host(client) + for _ in range(5): + self.assertEqual(host, await self.read_from_which_host(client)) + finally: + ServerDescription._host_to_round_trip_time.clear() + + async def test_primary(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY) + + async def test_primary_with_tags(self): + # Tags not allowed with PRIMARY + with self.assertRaises(ConfigurationError): + await self.async_rs_client(tag_sets=[{"dc": "ny"}]) + + async def test_primary_preferred(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) + + async def test_secondary(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY) + + async def test_secondary_preferred(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED) + + async def test_nearest(self): + # With high localThresholdMS, expect to read from any + # member + c = await self.async_rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds + + data_members = {await self.client.primary} | await self.client.secondaries + + # This is a probabilistic test; track which members we've read from so + # far, and keep reading until we've used all the members or give up. + # Chance of using only 2 of 3 members 10k times if there's no bug = + # 3 * (2/3)**10000, very low. 
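+        # (Loosely: each read is roughly uniform over the 3 data-bearing
+        # members, so a given member is never chosen in 10k reads with
+        # probability about (2/3)**10000; a union bound over the 3 members
+        # gives the figure above.)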
+ used: set = set() + i = 0 + while data_members.difference(used) and i < 10000: + address = await self.read_from_which_host(c) + used.add(address) + i += 1 + + not_used = data_members.difference(used) + latencies = ", ".join( + "%s: %sms" % (server.description.address, server.description.round_trip_time) + for server in await (await c._get_topology()).select_servers( + readable_server_selector, _Op.TEST + ) + ) + + self.assertFalse( + not_used, + "Expected to use primary and all secondaries for mode NEAREST," + f" but didn't use {not_used}\nlatencies: {latencies}", + ) + + +class ReadPrefTester(AsyncMongoClient): + def __init__(self, *args, **kwargs): + self.has_read_from = set() + client_options = async_client_context.client_options + client_options.update(kwargs) + super().__init__(*args, **client_options) + + async def _conn_for_reads(self, read_preference, session, operation): + context = await super()._conn_for_reads(read_preference, session, operation) + return context + + @contextlib.asynccontextmanager + async def _conn_from_server(self, read_preference, server, session): + context = super()._conn_from_server(read_preference, server, session) + async with context as (conn, read_preference): + await self.record_a_read(conn.address) + yield conn, read_preference + + async def record_a_read(self, address): + server = await (await self._get_topology()).select_server_by_address(address, _Op.TEST, 0) + self.has_read_from.add(server) + + +_PREF_MAP = [ + (Primary, SERVER_TYPE.RSPrimary), + (PrimaryPreferred, SERVER_TYPE.RSPrimary), + (Secondary, SERVER_TYPE.RSSecondary), + (SecondaryPreferred, SERVER_TYPE.RSSecondary), + (Nearest, "any"), +] + + +class TestCommandAndReadPreference(AsyncIntegrationTest): + c: ReadPrefTester + client_version: Version + + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.c = ReadPrefTester( + # Ignore round trip times, to test ReadPreference modes only. + localThresholdMS=1000 * 1000, + ) + self.client_version = await Version.async_from_client(self.c) + # mapReduce fails if the collection does not exist. 
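+        # (mapReduce is not actually exercised in this file any more; the
+        # insert below just ensures the collection exists before the read
+        # helpers run.)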
+ coll = self.c.pymongo_test.get_collection( + "test", write_concern=WriteConcern(w=async_client_context.w) + ) + await coll.insert_one({}) + + async def asyncTearDown(self): + await self.c.drop_database("pymongo_test") + await self.c.close() + + async def executed_on_which_server(self, client, fn, *args, **kwargs): + """Execute fn(*args, **kwargs) and return the Server instance used.""" + client.has_read_from.clear() + await fn(*args, **kwargs) + self.assertEqual(1, len(client.has_read_from)) + return one(client.has_read_from) + + async def assertExecutedOn(self, server_type, client, fn, *args, **kwargs): + server = await self.executed_on_which_server(client, fn, *args, **kwargs) + self.assertEqual( + SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type] + ) + + async def _test_fn(self, server_type, fn): + for _ in range(10): + if server_type == "any": + used = set() + for _ in range(1000): + server = await self.executed_on_which_server(self.c, fn) + used.add(server.description.address) + if len(used) == len(await self.c.secondaries) + 1: + # Success + break + + assert await self.c.primary is not None + unused = (await self.c.secondaries).union({await self.c.primary}).difference(used) + if unused: + self.fail("Some members not used for NEAREST: %s" % (unused)) + else: + await self.assertExecutedOn(server_type, self.c, fn) + + async def _test_primary_helper(self, func): + # Helpers that ignore read preference. + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs): + for mode, server_type in _PREF_MAP: + new_coll = coll.with_options(read_preference=mode()) + + async def func(): + return await getattr(new_coll, meth)(*args, **kwargs) + + if secondary_ok: + await self._test_fn(server_type, func) + else: + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def test_command(self): + # Test that the generic command helper obeys the read preference + # passed to it. + for mode, server_type in _PREF_MAP: + + async def func(): + return await self.c.pymongo_test.command("dbStats", read_preference=mode()) + + await self._test_fn(server_type, func) + + async def test_create_collection(self): + # create_collection runs listCollections on the primary to check if + # the collection already exists. + async def func(): + return await self.c.pymongo_test.create_collection( + "some_collection%s" % random.randint(0, sys.maxsize) + ) + + await self._test_primary_helper(func) + + async def test_count_documents(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) + + async def test_estimated_document_count(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count") + + async def test_distinct(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a") + + async def test_aggregate(self): + await self._test_coll_helper( + True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}] + ) + + async def test_aggregate_write(self): + # 5.0 servers support $out on secondaries. 
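+        # (On older servers the $out stage must run on the primary, which is
+        # what _test_coll_helper asserts whenever secondary_ok is False.)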
+ secondary_ok = async_client_context.version.at_least(5, 0) + await self._test_coll_helper( + secondary_ok, + self.c.pymongo_test.test, + "aggregate", + [{"$project": {"_id": 1}}, {"$out": "agg_write_test"}], + ) + + +class TestMovingAverage(unittest.TestCase): + def test_moving_average(self): + avg = MovingAverage() + self.assertIsNone(avg.get()) + avg.add_sample(10) + self.assertAlmostEqual(10, avg.get()) # type: ignore + avg.add_sample(20) + self.assertAlmostEqual(12, avg.get()) # type: ignore + avg.add_sample(30) + self.assertAlmostEqual(15.6, avg.get()) # type: ignore + + +class TestMongosAndReadPreference(AsyncIntegrationTest): + def test_read_preference_document(self): + pref = Primary() + self.assertEqual(pref.document, {"mode": "primary"}) + + pref = PrimaryPreferred() + self.assertEqual(pref.document, {"mode": "primaryPreferred"}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Secondary() + self.assertEqual(pref.document, {"mode": "secondary"}) + pref = Secondary(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]}) + pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + pref = SecondaryPreferred() + self.assertEqual(pref.document, {"mode": "secondaryPreferred"}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Nearest() + self.assertEqual(pref.document, {"mode": "nearest"}) + pref = Nearest(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]}) + pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + with self.assertRaises(TypeError): + # Float is prohibited. + Nearest(max_staleness=1.5) # type: ignore + + with self.assertRaises(ValueError): + Nearest(max_staleness=0) + + with self.assertRaises(ValueError): + Nearest(max_staleness=-2) + + def test_read_preference_document_hedge(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondary": Secondary, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + for mode, cls in cases.items(): + with self.assertRaises(TypeError): + cls(hedge=[]) # type: ignore + with _ignore_deprecations(): + pref = cls(hedge={}) + self.assertEqual(pref.document, {"mode": mode}) + out = _maybe_add_read_preference({}, pref) + if cls == SecondaryPreferred: + # SecondaryPreferred without hedge doesn't add $readPreference. 
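+                    # (A bare secondaryPreferred read is conveyed by the
+                    # secondaryOk wire flag alone, so no $readPreference
+                    # document is needed; compare
+                    # test_maybe_add_read_preference below.)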
+                    self.assertEqual(out, {})
+                else:
+                    self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+                hedge: dict[str, Any] = {"enabled": True}
+                pref = cls(hedge=hedge)
+                self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
+                out = _maybe_add_read_preference({}, pref)
+                self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+                hedge = {"enabled": False}
+                pref = cls(hedge=hedge)
+                self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
+                out = _maybe_add_read_preference({}, pref)
+                self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+                hedge = {"enabled": False, "extra": "option"}
+                pref = cls(hedge=hedge)
+                self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
+                out = _maybe_add_read_preference({}, pref)
+                self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+    def test_read_preference_hedge_deprecated(self):
+        cases = {
+            "primaryPreferred": PrimaryPreferred,
+            "secondary": Secondary,
+            "secondaryPreferred": SecondaryPreferred,
+            "nearest": Nearest,
+        }
+        for _, cls in cases.items():
+            with self.assertRaises(DeprecationWarning):
+                cls(hedge={"enabled": True})
+
+    async def test_send_hedge(self):
+        cases = {
+            "primaryPreferred": PrimaryPreferred,
+            "secondaryPreferred": SecondaryPreferred,
+            "nearest": Nearest,
+        }
+        if await async_client_context.supports_secondary_read_pref:
+            cases["secondary"] = Secondary
+        listener = OvertCommandListener()
+        client = await self.async_rs_client(event_listeners=[listener])
+        await client.admin.command("ping")
+        for _mode, cls in cases.items():
+            with _ignore_deprecations():
+                pref = cls(hedge={"enabled": True})
+            coll = client.test.get_collection("test", read_preference=pref)
+            listener.reset()
+            await coll.find_one()
+            started = listener.started_events
+            self.assertEqual(len(started), 1, started)
+            cmd = started[0].command
+            if async_client_context.is_rs or async_client_context.is_mongos:
+                self.assertIn("$readPreference", cmd)
+                self.assertEqual(cmd["$readPreference"], pref.document)
+            else:
+                self.assertNotIn("$readPreference", cmd)
+
+    def test_maybe_add_read_preference(self):
+        # Primary doesn't add $readPreference
+        out = _maybe_add_read_preference({}, Primary())
+        self.assertEqual(out, {})
+
+        pref = PrimaryPreferred()
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+        pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}])
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+        pref = Secondary()
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+        pref = Secondary(tag_sets=[{"dc": "nyc"}])
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+        # SecondaryPreferred without tag_sets or max_staleness doesn't add
+        # $readPreference
+        pref = SecondaryPreferred()
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, {})
+        pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}])
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+        pref = SecondaryPreferred(max_staleness=120)
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+        pref = Nearest()
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+        pref = Nearest(tag_sets=[{"dc": "nyc"}])
+        out = _maybe_add_read_preference({}, pref)
+        self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
+
+        criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))])
+        pref = Nearest()
+        out = _maybe_add_read_preference(criteria, pref)
+        self.assertEqual(
+            out,
+            SON(
+                [
+                    ("$query", {}),
+                    ("$orderby", SON([("_id", 1)])),
+                    ("$readPreference", pref.document),
+                ]
+            ),
+        )
+        pref = Nearest(tag_sets=[{"dc": "nyc"}])
+        out = _maybe_add_read_preference(criteria, pref)
+        self.assertEqual(
+            out,
+            SON(
+                [
+                    ("$query", {}),
+                    ("$orderby", SON([("_id", 1)])),
+                    ("$readPreference", pref.document),
+                ]
+            ),
+        )
+
+    @async_client_context.require_mongos
+    async def test_mongos(self):
+        res = await async_client_context.client.config.shards.find_one()
+        assert res is not None
+        shard = res["host"]
+        num_members = shard.count(",") + 1
+        if num_members == 1:
+            raise SkipTest("Need a replica set shard to test.")
+        coll = async_client_context.client.pymongo_test.get_collection(
+            "test", write_concern=WriteConcern(w=num_members)
+        )
+        await coll.drop()
+        res = await coll.insert_many([{} for _ in range(5)])
+        first_id = res.inserted_ids[0]
+        last_id = res.inserted_ids[-1]
+
+        # Note - this isn't a perfect test since there's no way to
+        # tell what shard member a query ran on.
+        for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()):
+            qcoll = coll.with_options(read_preference=pref)
+            results = await qcoll.find().sort([("_id", 1)]).to_list()
+            self.assertEqual(first_id, results[0]["_id"])
+            self.assertEqual(last_id, results[-1]["_id"])
+            results = await qcoll.find().sort([("_id", -1)]).to_list()
+            self.assertEqual(first_id, results[-1]["_id"])
+            self.assertEqual(last_id, results[0]["_id"])
+
+    @async_client_context.require_mongos
+    async def test_mongos_max_staleness(self):
+        # Sanity check that we're sending maxStalenessSeconds
+        coll = async_client_context.client.pymongo_test.get_collection(
+            "test", read_preference=SecondaryPreferred(max_staleness=120)
+        )
+        # No error
+        await coll.find_one()
+
+        coll = async_client_context.client.pymongo_test.get_collection(
+            "test", read_preference=SecondaryPreferred(max_staleness=10)
+        )
+        try:
+            await coll.find_one()
+        except OperationFailure as exc:
+            self.assertEqual(160, exc.code)
+        else:
+            self.fail("mongos accepted invalid staleness")
+
+        coll = (
+            await self.async_single_client(
+                readPreference="secondaryPreferred", maxStalenessSeconds=120
+            )
+        ).pymongo_test.test
+        # No error
+        await coll.find_one()
+
+        coll = (
+            await self.async_single_client(
+                readPreference="secondaryPreferred", maxStalenessSeconds=10
+            )
+        ).pymongo_test.test
+        try:
+            await coll.find_one()
+        except OperationFailure as exc:
+            self.assertEqual(160, exc.code)
+        else:
+            self.fail("mongos accepted invalid staleness")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py
new file mode 100644
index 0000000000..86f79fd28d
--- /dev/null
+++ b/test/asynchronous/test_read_write_concern_spec.py
@@ -0,0 +1,344 @@
+# Copyright 2018-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the read and write concern tests.""" +from __future__ import annotations + +import json +import os +import sys +import warnings +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils_shared import OvertCommandListener + +from pymongo import DESCENDING +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + WriteConcernError, + WriteError, + WTimeoutError, +) +from pymongo.operations import IndexModel, InsertOne +from pymongo.read_concern import ReadConcern +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") + + +class TestReadWriteConcernSpec(AsyncIntegrationTest): + async def test_omit_default_read_write_concern(self): + listener = OvertCommandListener() + # Client with default readConcern and writeConcern + client = await self.async_rs_or_single_client(event_listeners=[listener]) + collection = client.pymongo_test.collection + # Prepare for tests of find() and aggregate(). + await collection.insert_many([{} for _ in range(10)]) + self.addAsyncCleanup(collection.drop) + self.addAsyncCleanup(client.pymongo_test.collection2.drop) + # Commands MUST NOT send the default read/write concern to the server. + + async def rename_and_drop(): + # Ensure collection exists. 
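+            # rename and drop are also write commands, so they are exercised
+            # here to confirm that no default writeConcern is attached to
+            # them either.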
+ await collection.insert_one({}) + await collection.rename("collection2") + await client.pymongo_test.collection2.drop() + + async def insert_command_default_write_concern(): + await collection.database.command( + "insert", "collection", documents=[{}], write_concern=WriteConcern() + ) + + async def aggregate_op(): + await (await collection.aggregate([])).to_list() + + ops = [ + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), + ("insert_one", lambda: collection.insert_one({})), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), + ("command", insert_command_default_write_concern), + ] + + for name, f in ops: + listener.reset() + await f() + + self.assertGreaterEqual(len(listener.started_events), 1) + for _i, event in enumerate(listener.started_events): + self.assertNotIn( + "readConcern", + event.command, + f"{name} sent default readConcern with {event.command_name}", + ) + self.assertNotIn( + "writeConcern", + event.command, + f"{name} sent default writeConcern with {event.command_name}", + ) + + async def assertWriteOpsRaise(self, write_concern, expected_exception): + wc = write_concern.document + # Set socket timeout to avoid indefinite stalls + client = await self.async_rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) + db = client.get_database("pymongo_test") + coll = db.test + + async def insert_command(): + await coll.database.command( + "insert", + "new_collection", + documents=[{}], + writeConcern=write_concern.document, + parse_write_concern_error=True, + ) + + ops = [ + ("insert_one", lambda: coll.insert_one({})), + ("insert_many", lambda: coll.insert_many([{}, {}])), + ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: coll.delete_one({})), + ("delete_many", lambda: coll.delete_many({})), + ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), + ("command", insert_command), + ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), + # SERVER-46668 Delete all the documents in the collection to + # workaround a hang in createIndexes. + ("delete_many", lambda: coll.delete_many({})), + ("create_index", lambda: coll.create_index([("a", DESCENDING)])), + ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), + ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), + ("create", lambda: db.create_collection("new")), + ("rename", lambda: coll.rename("new")), + ("drop", lambda: db.new.drop()), + ] + # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. + if async_client_context.version[:2] != (3, 6): + ops.append(("drop_database", lambda: client.drop_database(db))) + + for name, f in ops: + # Ensure insert_many and bulk_write still raise BulkWriteError. 
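+            # The bulk APIs wrap write concern failures in BulkWriteError and
+            # surface the underlying errors via details["writeConcernErrors"].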
+ if name in ("insert_many", "bulk_write"): + expected = BulkWriteError + else: + expected = expected_exception + with self.assertRaises(expected, msg=name) as cm: + await f() + if expected == BulkWriteError: + bulk_result = cm.exception.details + assert bulk_result is not None + wc_errors = bulk_result["writeConcernErrors"] + self.assertTrue(wc_errors) + + @async_client_context.require_replica_set + async def test_raise_write_concern_error(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + assert async_client_context.w is not None + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError + ) + + @async_client_context.require_secondaries_count(1) + @async_client_context.require_test_commands + async def test_raise_wtimeout(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + self.addAsyncCleanup(self.enable_replication, async_client_context.client) + # Disable replication to guarantee a wtimeout error. + await self.disable_replication(async_client_context.client) + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError + ) + + @async_client_context.require_failCommand_fail_point + async def test_error_includes_errInfo(self): + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + async with self.fail_point(cause_wce): + # Write concern error on insert includes errInfo. + with self.assertRaises(WriteConcernError) as ctx: + await self.db.test.insert_one({}) + self.assertEqual(ctx.exception.details, expected_wce) + + # Test bulk_write as well. 
+ with self.assertRaises(BulkWriteError) as ctx: + await self.db.test.bulk_write([InsertOne({})]) + expected_details = { + "writeErrors": [], + "writeConcernErrors": [expected_wce], + "nInserted": 1, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + self.assertEqual(ctx.exception.details, expected_details) + + @async_client_context.require_version_min(4, 9) + async def test_write_error_details_exposes_errinfo(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + db = client.errinfotest + self.addAsyncCleanup(client.drop_database, "errinfotest") + validator = {"x": {"$type": "string"}} + await db.create_collection("test", validator=validator) + with self.assertRaises(WriteError) as ctx: + await db.test.insert_one({"x": 1}) + self.assertEqual(ctx.exception.code, 121) + self.assertIsNotNone(ctx.exception.details) + assert ctx.exception.details is not None + self.assertIsNotNone(ctx.exception.details.get("errInfo")) + for event in listener.succeeded_events: + if event.command_name == "insert": + self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) + break + else: + self.fail("Couldn't find insert event.") + + +def normalize_write_concern(concern): + result = {} + for key in concern: + if key.lower() == "wtimeoutms": + result["wtimeout"] = concern[key] + elif key == "journal": + result["j"] = concern[key] + else: + result[key] = concern[key] + return result + + +def create_connection_string_test(test_case): + def run_test(self): + uri = test_case["uri"] + valid = test_case["valid"] + warning = test_case["warning"] + + if not valid: + if warning is False: + self.assertRaises( + (ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False + ) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False) + else: + client = AsyncMongoClient(uri, connect=False) + if "writeConcern" in test_case: + document = client.write_concern.document + self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) + if "readConcern" in test_case: + document = client.read_concern.document + self.assertEqual(document, test_case["readConcern"]) + + return run_test + + +def create_document_test(test_case): + def run_test(self): + valid = test_case["valid"] + + if "writeConcern" in test_case: + normalized = normalize_write_concern(test_case["writeConcern"]) + if not valid: + self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) + else: + write_concern = WriteConcern(**normalized) + self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) + self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) + self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) + if "readConcern" in test_case: + # Any string for 'level' is equally valid + read_concern = ReadConcern(**test_case["readConcern"]) + self.assertEqual(read_concern.document, test_case["readConcernDocument"]) + self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) + + return run_test + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + if dirname == "operation": + # This directory is tested by TestOperations. 
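+            # (See the unified test generation at the bottom of this file.)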
+ continue + elif dirname == "connection-string": + create_test = create_connection_string_test + else: + create_test = create_document_test + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as test_stream: + test_cases = json.load(test_stream)["tests"] + + fname = os.path.splitext(filename)[0] + for test_case in test_cases: + new_test = create_test(test_case) + test_name = "test_{}_{}_{}".format( + dirname.replace("-", "_"), + fname.replace("-", "_"), + str(test_case["description"].lower().replace(" ", "_")), + ) + + new_test.__name__ = test_name + setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) + + +create_tests() + + +# Generate unified tests. +# PyMongo does not support MapReduce. +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "operation"), + module=__name__, + expected_failures=["MapReduce .*"], + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index bde7a9f2ee..10d9e738b4 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point from pymongo.errors import AutoReconnect @@ -31,10 +32,9 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, - async_set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/asynchronous/test_retryable_reads_unified.py b/test/asynchronous/test_retryable_reads_unified.py new file mode 100644 index 0000000000..e62d606810 --- /dev/null +++ b/test/asynchronous/test_retryable_reads_unified.py @@ -0,0 +1,46 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Retryable Reads unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") + +# Generate unified tests. +# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. 
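+# These show up as the expected_failures patterns passed to
+# generate_test_classes below.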
+globals().update( + generate_test_classes( + TEST_PATH, + module=__name__, + expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index 738ce04192..842233a3ef 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point sys.path[0:0] = [""] @@ -30,12 +31,11 @@ unittest, ) from test.asynchronous.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, OvertCommandListener, - async_set_fail_point, ) from test.version import Version @@ -137,6 +137,7 @@ async def asyncSetUp(self) -> None: self.deprecation_filter = DeprecationFilter() async def asyncTearDown(self) -> None: + await super().asyncTearDown() self.deprecation_filter.stop() @@ -196,6 +197,7 @@ async def asyncTearDown(self): SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) self.knobs.disable() + await super().asyncTearDown() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() diff --git a/test/asynchronous/test_retryable_writes_unified.py b/test/asynchronous/test_retryable_writes_unified.py new file mode 100644 index 0000000000..bb493e6010 --- /dev/null +++ b/test/asynchronous/test_retryable_writes_unified.py @@ -0,0 +1,39 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Retryable Writes unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_run_command.py b/test/asynchronous/test_run_command.py new file mode 100644 index 0000000000..3ac8c32706 --- /dev/null +++ b/test/asynchronous/test_run_command.py @@ -0,0 +1,41 @@ +# Copyright 2024-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run Command unified tests.""" +from __future__ import annotations + +import os +import unittest +from pathlib import Path +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "unified"), + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_sdam_monitoring_spec.py b/test/asynchronous/test_sdam_monitoring_spec.py new file mode 100644 index 0000000000..71ec6c6b46 --- /dev/null +++ b/test/asynchronous/test_sdam_monitoring_spec.py @@ -0,0 +1,374 @@ +# Copyright 2016 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the sdam monitoring spec tests.""" +from __future__ import annotations + +import asyncio +import json +import os +import sys +import time +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest +from test.utils_shared import ( + ServerAndTopologyEventListener, + async_wait_until, + server_name_to_type, +) + +from bson.json_util import object_hook +from pymongo import AsyncMongoClient, monitoring +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.monitor import Monitor +from pymongo.common import clean_node +from pymongo.errors import ConnectionFailure, NotPrimaryError +from pymongo.hello import Hello +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE + +_IS_SYNC = False + +# Location of JSON test specifications. 
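+# The asynchronous tests live one directory below the shared JSON specs,
+# hence the extra .parent when _IS_SYNC is False.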
+if _IS_SYNC:
+    TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sdam_monitoring")
+else:
+    TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sdam_monitoring")
+
+
+def compare_server_descriptions(expected, actual):
+    if (expected["address"] != "{}:{}".format(*actual.address)) or (
+        server_name_to_type(expected["type"]) != actual.server_type
+    ):
+        return False
+    expected_hosts = set(expected["arbiters"] + expected["passives"] + expected["hosts"])
+    return expected_hosts == {"{}:{}".format(*s) for s in actual.all_hosts}
+
+
+def compare_topology_descriptions(expected, actual):
+    if TOPOLOGY_TYPE.__getattribute__(expected["topologyType"]) != actual.topology_type:
+        return False
+    expected = expected["servers"]
+    actual = actual.server_descriptions()
+    if len(expected) != len(actual):
+        return False
+    for exp_server in expected:
+        for _address, actual_server in actual.items():
+            if compare_server_descriptions(exp_server, actual_server):
+                break
+        else:
+            return False
+    return True
+
+
+def compare_events(expected_dict, actual):
+    if not expected_dict:
+        return False, "Error: Bad expected value in YAML test"
+    if not actual:
+        return False, "Error: Event published was None"
+
+    expected_type, expected = list(expected_dict.items())[0]
+
+    if expected_type == "server_opening_event":
+        if not isinstance(actual, monitoring.ServerOpeningEvent):
+            return False, "Expected ServerOpeningEvent, got %s" % (actual.__class__)
+        if expected["address"] != "{}:{}".format(*actual.server_address):
+            return (
+                False,
+                "ServerOpeningEvent published with wrong address (expected {}, got {})".format(
+                    expected["address"], actual.server_address
+                ),
+            )
+
+    elif expected_type == "server_description_changed_event":
+        if not isinstance(actual, monitoring.ServerDescriptionChangedEvent):
+            return (False, "Expected ServerDescriptionChangedEvent, got %s" % (actual.__class__))
+        if expected["address"] != "{}:{}".format(*actual.server_address):
+            return (
+                False,
+                "ServerDescriptionChangedEvent has wrong address (expected {}, got {})".format(
+                    expected["address"], actual.server_address
+                ),
+            )
+
+        if not compare_server_descriptions(expected["newDescription"], actual.new_description):
+            return (False, "New ServerDescription incorrect in ServerDescriptionChangedEvent")
+        if not compare_server_descriptions(
+            expected["previousDescription"], actual.previous_description
+        ):
+            return (
+                False,
+                "Previous ServerDescription incorrect in ServerDescriptionChangedEvent",
+            )
+
+    elif expected_type == "server_closed_event":
+        if not isinstance(actual, monitoring.ServerClosedEvent):
+            return False, "Expected ServerClosedEvent, got %s" % (actual.__class__)
+        if expected["address"] != "{}:{}".format(*actual.server_address):
+            return (
+                False,
+                "ServerClosedEvent published with wrong address (expected {}, got {})".format(
+                    expected["address"], actual.server_address
+                ),
+            )
+
+    elif expected_type == "topology_opening_event":
+        if not isinstance(actual, monitoring.TopologyOpenedEvent):
+            return False, "Expected TopologyOpenedEvent, got %s" % (actual.__class__)
+
+    elif expected_type == "topology_description_changed_event":
+        if not isinstance(actual, monitoring.TopologyDescriptionChangedEvent):
+            return (
+                False,
+                "Expected TopologyDescriptionChangedEvent, got %s" % (actual.__class__),
+            )
+        if not compare_topology_descriptions(expected["newDescription"], actual.new_description):
+            return (
+                False,
+                "New TopologyDescription incorrect in TopologyDescriptionChangedEvent",
+            )
+        if not compare_topology_descriptions(
+            expected["previousDescription"], actual.previous_description
+        ):
+            return (
+                False,
+                "Previous TopologyDescription incorrect in TopologyDescriptionChangedEvent",
+            )
+
+    elif expected_type == "topology_closed_event":
+        if not isinstance(actual, monitoring.TopologyClosedEvent):
+            return False, "Expected TopologyClosedEvent, got %s" % (actual.__class__)
+
+    else:
+        return False, f"Incorrect event: expected {expected_type}, actual {actual}"
+
+    return True, ""
+
+
+def compare_multiple_events(i, expected_results, actual_results):
+    events_in_a_row = []
+    j = i
+    while j < len(expected_results) and isinstance(actual_results[j], actual_results[i].__class__):
+        events_in_a_row.append(actual_results[j])
+        j += 1
+    message = ""
+    for event in events_in_a_row:
+        for k in range(i, j):
+            passed, message = compare_events(expected_results[k], event)
+            if passed:
+                expected_results[k] = None
+                break
+        else:
+            return i, False, message
+    return j, True, ""
+
+
+class TestAllScenarios(AsyncIntegrationTest):
+    async def asyncSetUp(self):
+        await super().asyncSetUp()
+        self.all_listener = ServerAndTopologyEventListener()
+
+
+def create_test(scenario_def):
+    async def run_scenario(self):
+        with client_knobs(events_queue_frequency=0.05, min_heartbeat_interval=0.05):
+            await _run_scenario(self)
+
+    async def _run_scenario(self):
+        class NoopMonitor(Monitor):
+            """Override the _run method to do nothing."""
+
+            async def _run(self):
+                await asyncio.sleep(0.05)
+
+        m = AsyncMongoClient(
+            host=scenario_def["uri"],
+            port=27017,
+            event_listeners=[self.all_listener],
+            _monitor_class=NoopMonitor,
+        )
+        topology = await m._get_topology()
+
+        try:
+            for phase in scenario_def["phases"]:
+                for source, response in phase.get("responses", []):
+                    source_address = clean_node(source)
+                    await topology.on_change(
+                        ServerDescription(
+                            address=source_address, hello=Hello(response), round_trip_time=0
+                        )
+                    )
+
+                expected_results = phase["outcome"]["events"]
+                expected_len = len(expected_results)
+                await async_wait_until(
+                    lambda: len(self.all_listener.results) >= expected_len,
+                    "publish all events",
+                    timeout=15,
+                )
+
+                # Wait some time to catch possible lagging extra events.
+                await async_wait_until(lambda: topology._events.empty(), "publish lagging events")
+
+                i = 0
+                while i < expected_len:
+                    result = (
+                        self.all_listener.results[i] if len(self.all_listener.results) > i else None
+                    )
+                    # The order of ServerOpening/ClosedEvents doesn't matter
+                    if isinstance(
+                        result, (monitoring.ServerOpeningEvent, monitoring.ServerClosedEvent)
+                    ):
+                        i, passed, message = compare_multiple_events(
+                            i, expected_results, self.all_listener.results
+                        )
+                        self.assertTrue(passed, message)
+                    else:
+                        self.assertTrue(*compare_events(expected_results[i], result))
+                        i += 1
+
+                # Assert no extra events.
+                extra_events = self.all_listener.results[expected_len:]
+                if extra_events:
+                    self.fail(f"Extra events {extra_events!r}")
+
+                self.all_listener.reset()
+        finally:
+            await m.close()
+
+    return run_scenario
+
+
+def create_tests():
+    for dirpath, _, filenames in os.walk(TEST_PATH):
+        for filename in filenames:
+            with open(os.path.join(dirpath, filename)) as scenario_stream:
+                scenario_def = json.load(scenario_stream, object_hook=object_hook)
+
+            # Construct test from scenario.
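+            # Each JSON file holds a single scenario; bind it to a
+            # test_<filename> method on TestAllScenarios.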
+ new_test = create_test(scenario_def) + test_name = f"test_{os.path.splitext(filename)[0]}" + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + + +class TestSdamMonitoring(AsyncIntegrationTest): + knobs: client_knobs + listener: ServerAndTopologyEventListener + test_client: AsyncMongoClient + coll: AsyncCollection + + @classmethod + def setUpClass(cls): + # Speed up the tests by decreasing the event publish frequency. + cls.knobs = client_knobs( + events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 + ) + cls.knobs.enable() + cls.listener = ServerAndTopologyEventListener() + + @classmethod + def tearDownClass(cls): + cls.knobs.disable() + + @async_client_context.require_failCommand_fail_point + async def asyncSetUp(self): + await super().asyncSetUp() + + retry_writes = async_client_context.supports_transactions() + self.test_client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=retry_writes + ) + self.coll = self.test_client[self.client.db.name].test + await self.coll.insert_one({}) + self.listener.reset() + + async def asyncTearDown(self): + await super().asyncTearDown() + + async def _test_app_error(self, fail_command_opts, expected_error): + address = await self.test_client.address + + # Test that an application error causes a ServerDescriptionChangedEvent + # to be published. + data = {"failCommands": ["insert"]} + data.update(fail_command_opts) + fail_insert = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": data, + } + async with self.fail_point(fail_insert): + if self.test_client.options.retry_writes: + await self.coll.insert_one({}) + else: + with self.assertRaises(expected_error): + await self.coll.insert_one({}) + await self.coll.insert_one({}) + + def marked_unknown(event): + return ( + isinstance(event, monitoring.ServerDescriptionChangedEvent) + and event.server_address == address + and not event.new_description.is_server_type_known + ) + + def discovered_node(event): + return ( + isinstance(event, monitoring.ServerDescriptionChangedEvent) + and event.server_address == address + and not event.previous_description.is_server_type_known + and event.new_description.is_server_type_known + ) + + def marked_unknown_and_rediscovered(): + return ( + len(self.listener.matching(marked_unknown)) >= 1 + and len(self.listener.matching(discovered_node)) >= 1 + ) + + # Topology events are not published synchronously + await async_wait_until(marked_unknown_and_rediscovered, "rediscover node") + + # Expect a single ServerDescriptionChangedEvent for the network error. + marked_unknown_events = self.listener.matching(marked_unknown) + self.assertEqual(len(marked_unknown_events), 1, marked_unknown_events) + self.assertIsInstance(marked_unknown_events[0].new_description.error, expected_error) + + async def test_network_error_publishes_events(self): + await self._test_app_error({"closeConnection": True}, ConnectionFailure) + + # In 4.4+, not primary errors from failCommand don't cause SDAM state + # changes because topologyVersion is not incremented. 
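+    # (The client discards state changes whose topologyVersion is not newer,
+    # so the test below only applies through server 4.3.)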
+ @async_client_context.require_version_max(4, 3) + async def test_not_primary_error_publishes_events(self): + await self._test_app_error( + {"errorCode": 10107, "closeConnection": False, "errorLabels": ["RetryableWriteError"]}, + NotPrimaryError, + ) + + async def test_shutdown_error_publishes_events(self): + await self._test_app_error( + {"errorCode": 91, "closeConnection": False, "errorLabels": ["RetryableWriteError"]}, + NotPrimaryError, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py new file mode 100644 index 0000000000..f98a05ee91 --- /dev/null +++ b/test/asynchronous/test_server_selection.py @@ -0,0 +1,212 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +from pymongo import AsyncMongoClient, ReadPreference +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.errors import ServerSelectionTimeoutError +from pymongo.hello import HelloCompat +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector +from pymongo.typings import strip_optional + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils import async_wait_until +from test.asynchronous.utils_selection_tests import ( + create_selection_tests, + get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, + make_server_description, +) +from test.utils_shared import ( + FunctionCallRecorder, + OvertCommandListener, +) + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent, "server_selection", "server_selection" + ) +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "server_selection" + ) + + +class SelectionStoreSelector: + """No-op selector that keeps track of what was passed to it.""" + + def __init__(self): + self.selection = None + + def __call__(self, selection): + self.selection = selection + return selection + + +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore + pass + + +class TestCustomServerSelectorFunction(AsyncIntegrationTest): + @async_client_context.require_replica_set + async def test_functional_select_max_port_number_host(self): + # Selector that returns server with highest port number. + def custom_selector(servers): + ports = [s.address[1] for s in servers] + idx = ports.index(max(ports)) + return [servers[idx]] + + # Initialize client with appropriate listeners. 
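+        # A custom server_selector runs after read preference filtering but
+        # before the latency window, so every read should land on the member
+        # with the highest port, which the command listener can verify.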
+ listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener] + ) + coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll + self.addAsyncCleanup(client.drop_database, "testdb") + + # Wait the node list to be fully populated. + async def all_hosts_started(): + return len((await client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len( + client._topology._description.readable_servers + ) + + await async_wait_until(all_hosts_started, "receive heartbeat from all hosts") + + expected_port = max( + [strip_optional(n.address[1]) for n in client._topology._description.readable_servers] + ) + + # Insert 1 record and access it 10 times. + await coll.insert_one({"name": "John Doe"}) + for _ in range(10): + await coll.find_one({"name": "John Doe"}) + + # Confirm all find commands are run against appropriate host. + for command in listener.started_events: + if command.command_name == "find": + self.assertEqual(command.connection_id[1], expected_port) + + async def test_invalid_server_selector(self): + # Client initialization must fail if server_selector is not callable. + for selector_candidate in [[], 10, "string", {}]: + with self.assertRaisesRegex(ValueError, "must be a callable"): + AsyncMongoClient(connect=False, server_selector=selector_candidate) + + # None value for server_selector is OK. + AsyncMongoClient(connect=False, server_selector=None) + + @async_client_context.require_replica_set + async def test_selector_called(self): + selector = FunctionCallRecorder(lambda x: x) + + # Client setup. + mongo_client = await self.async_rs_or_single_client(server_selector=selector) + test_collection = mongo_client.testdb.test_collection + self.addAsyncCleanup(mongo_client.drop_database, "testdb") + + # Do N operations and test selector is called at least N times. + await test_collection.insert_one({"age": 20, "name": "John"}) + await test_collection.insert_one({"age": 31, "name": "Jane"}) + await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) + await test_collection.find_one({"name": "Roe"}) + self.assertGreaterEqual(selector.call_count, 4) + + @async_client_context.require_replica_set + async def test_latency_threshold_application(self): + selector = SelectionStoreSelector() + + scenario_def: dict = { + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}}, + ], + } + } + + # Create & populate Topology such that all but one server is too slow. + rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]] + min_rtt_idx = rtt_times.index(min(rtt_times)) + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) + topology = Topology(TopologySettings(**settings)) + await topology.open() + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Invoke server selection and assert no filtering based on latency + # prior to custom server selection logic kicking in. 
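+        # (local_threshold_ms=1 would normally exclude the slower members,
+        # but that filter is applied only after the custom selector runs.)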
+ server = await topology.select_server(ReadPreference.NEAREST, _Op.TEST) + assert selector.selection is not None + self.assertEqual(len(selector.selection), len(topology.description.server_descriptions())) + + # Ensure proper filtering based on latency after custom selection. + self.assertEqual(server.description.address, seeds[min_rtt_idx]) + + @async_client_context.require_replica_set + async def test_server_selector_bypassed(self): + selector = FunctionCallRecorder(lambda x: x) + + scenario_def = { + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}}, + ], + } + } + + # Create & populate Topology such that no server is writeable. + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) + topology = Topology(TopologySettings(**settings)) + await topology.open() + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Invoke server selection and assert no calls to our custom selector. + with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"): + await topology.select_server( + writable_server_selector, _Op.TEST, server_selection_timeout=0.1 + ) + self.assertEqual(selector.call_count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py new file mode 100644 index 0000000000..3fe448d4dd --- /dev/null +++ b/test/asynchronous/test_server_selection_in_window.py @@ -0,0 +1,178 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import asyncio +import os +import threading +from pathlib import Path +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import ConcurrentRunner +from test.asynchronous.utils_selection_tests import create_topology +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator +from test.utils_shared import ( + CMAPListener, + OvertCommandListener, + async_wait_until, +) + +from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False +# Location of JSON test specifications. 
+if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + async def run_scenario(self, scenario_def): + topology = await create_topology(scenario_def) + + # Update mock operation_count state: + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) + server = topology.get_server_by_address(address) + server.pool.operation_count = mock["operation_count"] + + pref = ReadPreference.NEAREST + counts = {address: 0 for address in topology._description.server_descriptions()} + + # Number of times to repeat server selection + iterations = scenario_def["iterations"] + for _ in range(iterations): + server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + counts[server.description.address] += 1 + + # Verify expected_frequencies + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] + for host_str, freq in expected_frequencies.items(): + address = clean_node(host_str) + actual_freq = float(counts[address]) / iterations + if freq == 0: + # Should be exactly 0. + self.assertEqual(actual_freq, 0) + else: + # Should be within 'tolerance'. + self.assertAlmostEqual(actual_freq, freq, delta=tolerance) + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def) + + return run_scenario + + +class CustomSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + Server selection in_window tests do not have a 'tests' field. + The whole file represents a single test case. + """ + return [scenario_def] + + +CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() + + +class FinderTask(ConcurrentRunner): + def __init__(self, collection, iterations): + super().__init__() + self.daemon = True + self.collection = collection + self.iterations = iterations + self.passed = False + + async def run(self): + for _ in range(self.iterations): + await self.collection.find_one({}) + self.passed = True + + +class TestProse(AsyncIntegrationTest): + async def frequencies(self, client, listener, n_finds=10): + coll = client.test.test + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + await task.start() + for task in tasks: + await task.join() + for task in tasks: + self.assertTrue(task.passed) + + events = listener.started_events + self.assertEqual(len(events), n_finds * N_TASKS) + nodes = client.nodes + self.assertEqual(len(nodes), 2) + freqs = {address: 0.0 for address in nodes} + for event in events: + freqs[event.connection_id] += 1 + for address in freqs: + freqs[address] = freqs[address] / float(len(events)) + return freqs + + @async_client_context.require_failCommand_appName + @async_client_context.require_multiple_mongoses + async def test_load_balancing(self): + listener = OvertCommandListener() + cmap_listener = CMAPListener() + # PYTHON-2584: Use a large localThresholdMS to avoid the impact of + # varying RTTs. 
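+        # With both mongoses inside the latency window, selection is driven
+        # by in-window operation counts rather than round-trip times.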
+ client = await self.async_rs_client( + async_client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener, cmap_listener], + localThresholdMS=30000, + minPoolSize=10, + ) + await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") + # Wait for both pools to be populated. + await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. + delay_finds = { + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", + }, + } + async with self.fail_point(delay_finds): + nodes = async_client_context.client.nodes + self.assertEqual(len(nodes), 1) + delayed_server = next(iter(nodes)) + freqs = await self.frequencies(client, listener) + self.assertLessEqual(freqs[delayed_server], 0.25) + listener.reset() + freqs = await self.frequencies(client, listener, n_finds=150) + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection_logging.py b/test/asynchronous/test_server_selection_logging.py new file mode 100644 index 0000000000..6b0975318a --- /dev/null +++ b/test/asynchronous/test_server_selection_logging.py @@ -0,0 +1,45 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the server selection logging unified format spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") + + +globals().update( + generate_test_classes( + TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection_rtt.py b/test/asynchronous/test_server_selection_rtt.py new file mode 100644 index 0000000000..1f8f6bc7df --- /dev/null +++ b/test/asynchronous/test_server_selection_rtt.py @@ -0,0 +1,77 @@ +# Copyright 2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the topology module.""" +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous import AsyncPyMongoTestCase + +from pymongo.read_preferences import MovingAverage + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt") + + +class TestAllScenarios(AsyncPyMongoTestCase): + pass + + +def create_test(scenario_def): + def run_scenario(self): + moving_average = MovingAverage() + + if scenario_def["avg_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["avg_rtt_ms"]) + + if scenario_def["new_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["new_rtt_ms"]) + + self.assertAlmostEqual(moving_average.get(), scenario_def["new_avg_rtt"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json.load(scenario_stream) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 42bc253b56..3655f49aab 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.asynchronous.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from pymongo.synchronous.mongo_client import MongoClient @@ -27,22 +30,22 @@ from test.asynchronous import ( AsyncIntegrationTest, - AsyncPyMongoTestCase, AsyncUnitTest, SkipTest, async_client_context, unittest, ) -from test.utils import ( +from test.asynchronous.helpers import client_knobs +from test.utils_shared import ( EventListener, - ExceptionCatchingThread, + HeartbeatEventListener, OvertCommandListener, async_wait_until, ) from bson import DBRef from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket -from pymongo import ASCENDING, AsyncMongoClient, monitoring +from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.helpers import anext @@ -184,8 +187,7 @@ async def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @async_client_context.require_sync - def test_implicit_sessions_checkout(self): + async def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. 
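+        # With maxPoolSize=1 the operations below contend for a single
+        # connection; observing at most one lsid across all of them shows the
+        # server session is only acquired after checkout succeeds.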
succeeded = False @@ -193,7 +195,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = OvertCommandListener() - client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -210,26 +212,27 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + async def target(op, *args): + if iscoroutinefunction(op): + res = await op(*args) + else: + res = op(*args) if isinstance(res, (AsyncCursor, AsyncCommandCursor)): - list(res) # type: ignore[call-overload] + await res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) - client.close() + await tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + await t.join() + self.assertIsNone(t.exc) + await client.close() lsid_set.clear() for i in listener.started_events: if i.command.get("lsid"): @@ -538,9 +541,10 @@ async def find(session=None): (bucket.download_to_stream_by_name, ["f", sio], {}), (find, [], {}), (bucket.rename, [1, "f2"], {}), + (bucket.rename_by_name, ["f2", "f3"], {}), # Delete both files so _test_ops can run these operations twice. (bucket.delete, [1], {}), - (bucket.delete, [2], {}), + (bucket.delete_by_name, ["f"], {}), ) async def test_gridfsbucket_cursor(self): @@ -1133,12 +1137,10 @@ async def asyncSetUp(self): if "$clusterTime" not in (await async_client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands async def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = await self.async_rs_or_single_client( - event_listeners=[listener], heartbeatFrequencyMS=999999 - ) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) @@ -1217,6 +1219,40 @@ async def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + async def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = await self.async_single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (await c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + await self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. 
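For background on prose test 20: drivers gossip $clusterTime by remembering the highest value seen in any command reply and attaching it to later commands, while SDAM heartbeats must neither send nor consume it. A small illustration using public session APIs (variable names are illustrative):

    # Capture the highest $clusterTime observed on a command reply...
    cluster_time = (await client.admin.command("ping"))["$clusterTime"]
    # ...and gossip it into a session; subsequent operations on the session
    # send the greatest cluster time the session has seen.
    async with client.start_session() as session:
        session.advance_cluster_time(cluster_time)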
+ heartbeat_listener.reset() + + async def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + await async_wait_until( + next_heartbeat, "never found pair of heartbeat started + succeeded events" + ) + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. + cmd_listener.reset() + await c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_sessions_unified.py b/test/asynchronous/test_sessions_unified.py new file mode 100644 index 0000000000..b4cbac5704 --- /dev/null +++ b/test/asynchronous/test_sessions_unified.py @@ -0,0 +1,40 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Sessions unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") + + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py new file mode 100644 index 0000000000..3dcd21ef1d --- /dev/null +++ b/test/asynchronous/test_srv_polling.py @@ -0,0 +1,367 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
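As in the sessions runner just above, generate_test_classes turns each JSON scenario into a TestCase subclass, and the globals().update(...) call is what makes unittest discovery find them. A stripped-down sketch of the same pattern, with hypothetical names:

    import unittest

    def make_case(class_name, scenario):
        # One generated TestCase per JSON scenario.
        def test_scenario(self):
            self.assertIn("description", scenario)

        return type(class_name, (unittest.TestCase,), {"test_scenario": test_scenario})

    # Installing the class into module globals lets `python -m unittest` discover
    # it exactly as if it had been written by hand.
    globals().update({"TestExample": make_case("TestExample", {"description": "demo"})})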
+ +"""Run the SRV support tests.""" +from __future__ import annotations + +import asyncio +import sys +import time +from test.utils_shared import FunctionCallRecorder +from typing import Any + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest +from test.asynchronous.utils import async_wait_until + +import pymongo +from pymongo import common +from pymongo.asynchronous.srv_resolver import _have_dnspython +from pymongo.errors import ConfigurationError + +_IS_SYNC = False + +WAIT_TIME = 0.1 + + +class SrvPollingKnobs: + def __init__( + self, + ttl_time=None, + min_srv_rescan_interval=None, + nodelist_callback=None, + count_resolver_calls=False, + ): + self.ttl_time = ttl_time + self.min_srv_rescan_interval = min_srv_rescan_interval + self.nodelist_callback = nodelist_callback + self.count_resolver_calls = count_resolver_calls + + self.old_min_srv_rescan_interval = None + self.old_dns_resolver_response = None + + def enable(self): + self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL + self.old_dns_resolver_response = ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) + + if self.min_srv_rescan_interval is not None: + common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval + + async def mock_get_hosts_and_min_ttl(resolver, *args): + assert self.old_dns_resolver_response is not None + nodes, ttl = await self.old_dns_resolver_response(resolver) + if self.nodelist_callback is not None: + nodes = self.nodelist_callback() + if self.ttl_time is not None: + ttl = self.ttl_time + return nodes, ttl + + patch_func: Any + if self.count_resolver_calls: + patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl) + else: + patch_func = mock_get_hosts_and_min_ttl + + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + + def __enter__(self): + self.enable() + + def disable(self): + common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + self.old_dns_resolver_response + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + +class TestSrvPolling(AsyncPyMongoTestCase): + BASE_SRV_RESPONSE = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27018), + ] + + CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc" + + async def asyncSetUp(self): + # Patch timeouts to ensure short rescan SRV interval. + self.client_knobs = client_knobs( + heartbeat_frequency=WAIT_TIME, + min_heartbeat_interval=WAIT_TIME, + events_queue_frequency=WAIT_TIME, + ) + self.client_knobs.enable() + + async def asyncTearDown(self): + self.client_knobs.disable() + + def get_nodelist(self, client): + return client._topology.description.server_descriptions().keys() + + async def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)): + """Check if the client._topology eventually sees all nodes in the + expected_nodelist. + """ + + def predicate(): + nodelist = self.get_nodelist(client) + if set(expected_nodelist) == set(nodelist): + return True + return False + + await async_wait_until(predicate, "see expected nodelist", timeout=timeout) + + async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)): + """Check if the client._topology ever deviates from seeing all nodes + in the expected_nodelist. Consistency is checked after sleeping for + (WAIT_TIME * 10) seconds. 
Also check that the resolver is called at + least once. + """ + + def predicate(): + if set(expected_nodelist) == set(self.get_nodelist(client)): + return ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) + return False + + await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) + + nodelist = self.get_nodelist(client) + if set(expected_nodelist) != set(nodelist): + msg = "Client nodelist %s changed unexpectedly (expected %s)" + raise self.fail(msg % (nodelist, expected_nodelist)) + self.assertGreaterEqual( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + 1, + "resolver was never called", + ) + return True + + async def run_scenario(self, dns_response, expect_change): + self.assertEqual(_have_dnspython(), True) + if callable(dns_response): + dns_resolver_response = dns_response + else: + + def dns_resolver_response(): + return dns_response + + if expect_change: + assertion_method = self.assert_nodelist_change + count_resolver_calls = False + expected_response = dns_response + else: + assertion_method = self.assert_nodelist_nochange + count_resolver_calls = True + expected_response = self.BASE_SRV_RESPONSE + + # Patch timeouts to ensure short test running times. + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + await self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) + # Patch list of hosts returned by DNS query. + with SrvPollingKnobs( + nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls + ): + await assertion_method(expected_response, client) + + async def test_addition(self): + response = self.BASE_SRV_RESPONSE[:] + response.append(("localhost.test.build.10gen.cc", 27019)) + await self.run_scenario(response, True) + + async def test_removal(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove(("localhost.test.build.10gen.cc", 27018)) + await self.run_scenario(response, True) + + async def test_replace_one(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove(("localhost.test.build.10gen.cc", 27018)) + response.append(("localhost.test.build.10gen.cc", 27019)) + await self.run_scenario(response, True) + + async def test_replace_both_with_one(self): + response = [("localhost.test.build.10gen.cc", 27019)] + await self.run_scenario(response, True) + + async def test_replace_both_with_two(self): + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + await self.run_scenario(response, True) + + async def test_dns_failures(self): + from dns import exception + + for exc in (exception.FormError, exception.TooBig, exception.Timeout): + + def response_callback(*args): + raise exc("DNS Failure!") + + await self.run_scenario(response_callback, False) + + async def test_dns_record_lookup_empty(self): + response: list = [] + await self.run_scenario(response, False) + + async def _test_recover_from_initial(self, initial_callback): + # Construct a valid final response callback distinct from base. 
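A note on count_resolver_calls: it wraps the patched resolver in FunctionCallRecorder (imported from test.utils_shared) so the assertions above can read call_count. The helper itself is outside this diff; a minimal sketch of the interface these tests rely on:

    class FunctionCallRecorder:
        """Wrap a callable and count its invocations."""

        def __init__(self, func):
            self._func = func
            self.call_count = 0

        def __call__(self, *args, **kwargs):
            self.call_count += 1
            # For an async wrappee this returns the coroutine,
            # which the caller awaits as usual.
            return self._func(*args, **kwargs)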
+ response_final = self.BASE_SRV_RESPONSE[:] + response_final.pop() + + def final_callback(): + return response_final + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, + min_srv_rescan_interval=WAIT_TIME, + nodelist_callback=initial_callback, + count_resolver_calls=True, + ): + # Client uses unpatched method to get initial nodelist + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + # Invalid DNS resolver response should not change nodelist. + await self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback + ): + # Nodelist should reflect new valid DNS resolver response. + await self.assert_nodelist_change(response_final, client) + + async def test_recover_from_initially_empty_seedlist(self): + def empty_seedlist(): + return [] + + await self._test_recover_from_initial(empty_seedlist) + + async def test_recover_from_initially_erroring_seedlist(self): + def erroring_seedlist(): + raise ConfigurationError + + await self._test_recover_from_initial(erroring_seedlist) + + async def test_10_all_dns_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_11_all_dns_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_12_new_dns_randomly_selected(self): + response = [ + ("localhost.test.build.10gen.cc", 27020), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27017), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await asyncio.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) + final_topology = set(client.topology_description.server_descriptions()) + self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology) + self.assertEqual(len(final_topology), 2) + + async def test_does_not_flipflop(self): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) + await client.aconnect() + old = set(client.topology_description.server_descriptions()) + await asyncio.sleep(4 * WAIT_TIME) + new = set(client.topology_description.server_descriptions()) + self.assertSetEqual(old, new) + + async def test_srv_service_name(self): + # Construct a valid final response callback distinct from base. 
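For reference, the srvServiceName option exercised below changes which SRV record is polled: the default lookup targets _mongodb._tcp.<host>, and a custom service name replaces the first label. A tiny illustrative helper (not driver API):

    def srv_query_name(fqdn, srv_service_name="mongodb"):
        # e.g. ("test22.test.build.10gen.cc", "customname")
        #   -> "_customname._tcp.test22.test.build.10gen.cc"
        return f"_{srv_service_name}._tcp.{fqdn}"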
+ response = [ + ("localhost.test.build.10gen.cc.", 27019), + ("localhost.test.build.10gen.cc.", 27020), + ] + + def nodelist_callback(): + return response + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + client = self.simple_client( + "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" + ) + await client.aconnect() + with SrvPollingKnobs(nodelist_callback=nodelist_callback): + await self.assert_nodelist_change(response, client) + + async def test_srv_waits_to_poll(self): + modified = [("localhost.test.build.10gen.cc", 27019)] + + def resolver_response(): + return modified + + with SrvPollingKnobs( + ttl_time=WAIT_TIME, + min_srv_rescan_interval=WAIT_TIME, + nodelist_callback=resolver_response, + ): + client = self.simple_client(self.CONNECTION_STRING) + await client.aconnect() + with self.assertRaises(AssertionError): + await self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2) + + def test_import_dns_resolver(self): + # Regression test for PYTHON-4407 + import dns.resolver + + self.assertTrue(hasattr(dns.resolver, "resolve")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py new file mode 100644 index 0000000000..582e3f9267 --- /dev/null +++ b/test/asynchronous/test_ssl.py @@ -0,0 +1,674 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
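Before moving on to the TLS tests: the PYTHON-4407 regression test at the end of the SRV module above only asserts that dns.resolver.resolve exists. For completeness, an SRV seedlist lookup with dnspython 2.x looks roughly like this (a sketch; it performs live DNS resolution):

    import dns.resolver

    # Resolve the SRV records behind a mongodb+srv:// seed host.
    answer = dns.resolver.resolve("_mongodb._tcp.test1.test.build.10gen.cc", "SRV")
    nodes = [(rr.target.to_text(omit_final_dot=True), rr.port) for rr in answer]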
+ +"""Tests for SSL support.""" +from __future__ import annotations + +import os +import pathlib +import socket +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import ( + HAVE_IPADDRESS, + AsyncIntegrationTest, + AsyncPyMongoTestCase, + SkipTest, + async_client_context, + connected, + remove_all_users, + unittest, +) +from test.utils_shared import ( + EventListener, + OvertCommandListener, + cat_files, + ignore_deprecations, +) +from urllib.parse import quote_plus + +from pymongo import AsyncMongoClient, ssl_support +from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure +from pymongo.hello import HelloCompat +from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context +from pymongo.write_concern import WriteConcern + +_HAVE_PYOPENSSL = False +try: + # All of these must be available to use PyOpenSSL + import OpenSSL + import requests + import service_identity + + # Ensure service_identity>=18.1 is installed + from service_identity.pyopenssl import verify_ip_address + + from pymongo.ocsp_support import _load_trusted_ca_certs + + _HAVE_PYOPENSSL = True +except ImportError: + _load_trusted_ca_certs = None # type: ignore + + +if HAVE_SSL: + import ssl + +_IS_SYNC = False + +if _IS_SYNC: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates") +else: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates") + +CLIENT_PEM = os.path.join(CERT_PATH, "client.pem") +CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem") +CA_PEM = os.path.join(CERT_PATH, "ca.pem") +CA_BUNDLE_PEM = os.path.join(CERT_PATH, "trusted-ca.pem") +CRL_PEM = os.path.join(CERT_PATH, "crl.pem") +MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client" + +# To fully test this start a mongod instance (built with SSL support) like so: +# mongod --dbpath /path/to/data/directory --sslOnNormalPorts \ +# --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \ +# --sslCAFile /path/to/pymongo/test/certificates/ca.pem \ +# --sslWeakCertificateValidation +# Also, make sure you have 'server' as an alias for localhost in /etc/hosts +# +# Note: For all replica set tests to pass, the replica set configuration must +# use 'localhost' for the hostname of all hosts. 
+ + +class TestClientSSL(AsyncPyMongoTestCase): + @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.") + def test_no_ssl_module(self): + # Explicit + self.assertRaises(ConfigurationError, self.simple_client, ssl=True) + + # Implied + self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM) + + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + @ignore_deprecations + def test_config_ssl(self): + # Tests various ssl configurations + self.assertRaises(ValueError, self.simple_client, ssl="foo") + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) + self.assertRaises(TypeError, self.simple_client, ssl=0) + self.assertRaises(TypeError, self.simple_client, ssl=5.5) + self.assertRaises(TypeError, self.simple_client, ssl=[]) + + self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile") + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True) + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[]) + + # Test invalid combinations + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False + ) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False + ) + self.assertRaises( + ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False + ) + + @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") + def test_use_pyopenssl_when_available(self): + self.assertTrue(HAVE_PYSSL) + + @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") + def test_load_trusted_ca_certs(self): + trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM) + self.assertEqual(2, len(trusted_ca_certs)) + + +class TestSSL(AsyncIntegrationTest): + saved_port: int + + async def assertClientWorks(self, client): + coll = client.pymongo_test.ssl_test.with_options( + write_concern=WriteConcern(w=async_client_context.w) + ) + await coll.drop() + await coll.insert_one({"ssl": True}) + self.assertTrue((await coll.find_one())["ssl"]) + await coll.drop() + + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + async def asyncSetUp(self): + await super().asyncSetUp() + # MongoClient should connect to the primary by default. 
+ self.saved_port = AsyncMongoClient.PORT + AsyncMongoClient.PORT = await async_client_context.port + + async def asyncTearDown(self): + AsyncMongoClient.PORT = self.saved_port + + @async_client_context.require_tls + async def test_simple_ssl(self): + # Expects the server to be running with ssl and with + # no --sslPEMKeyFile or with --sslWeakCertificateValidation + await self.assertClientWorks(self.client) + + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_tlsCertificateKeyFilePassword(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL: + self.assertRaises( + ConfigurationError, + self.simple_client, + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, + tlsCertificateKeyFilePassword="qwerty", + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=1000, + ) + else: + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, + tlsCertificateKeyFilePassword="qwerty", + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=5000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + uri_fmt = ( + "mongodb://localhost/?ssl=true" + "&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty" + "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" + ) + await connected( + self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + ) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_implicitly_set(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + + # test that setting tlsCertificateKeyFile causes ssl to be set to True + client = self.simple_client( + await async_client_context.host, + await async_client_context.port, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + response = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" in response: + client = self.simple_client( + await async_client_context.pair, + replicaSet=response["setName"], + w=len(response["hosts"]), + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_validation(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + client = self.simple_client( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + response = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" in response: + if response["primary"].split(":")[0] != "localhost": + raise SkipTest( + "No hosts in the replicaset for 'localhost'. 
" + "Cannot validate hostname in the certificate" + ) + + client = self.simple_client( + "localhost", + replicaSet=response["setName"], + w=len(response["hosts"]), + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + + await self.assertClientWorks(client) + + if HAVE_IPADDRESS: + client = self.simple_client( + "127.0.0.1", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_no_auth + @ignore_deprecations + async def test_cert_ssl_uri_support(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" + "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" + ) + client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) + await self.assertClientWorks(client) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_server_resolvable + @ignore_deprecations + async def test_cert_ssl_validation_hostname_matching(self): + # Expects the server to be running with server.pem and ca.pem + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC) + self.assertFalse(ctx.check_hostname) + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) + self.assertTrue(ctx.check_hostname) + + response = await self.client.admin.command(HelloCompat.LEGACY_CMD) + + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + await connected( + self.simple_client( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + if "setName" in response: + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + await connected( + self.simple_client( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials, # type: ignore[arg-type] + ) + ) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_sync + @async_client_context.require_no_api_version + @ignore_deprecations + async def test_tlsCRLFile_support(self): + if not hasattr(ssl, 
"VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL: + self.assertRaises( + ConfigurationError, + self.simple_client, + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + tlsCRLFile=CRL_PEM, + serverSelectionTimeoutMS=1000, + ) + else: + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + tlsCRLFile=CRL_PEM, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000" + await connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore + + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCRLFile=%s" + "&tlsCAFile=%s&serverSelectionTimeoutMS=1000" + ) + with self.assertRaises(ConnectionFailure): + await connected( + self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + ) + + @async_client_context.require_tlsCertificateKeyFile + @async_client_context.require_server_resolvable + @ignore_deprecations + async def test_validation_with_system_ca_certs(self): + # Expects the server to be running with server.pem and ca.pem. + # + # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem + # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem + # --sslWeakCertificateValidation + # + self.patch_system_certs(CA_PEM) + with self.assertRaises(ConnectionFailure): + # Server cert is verified but hostname matching fails + await connected( + self.simple_client( + "server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] + ) + + # Server cert is verified. Disable hostname matching. + await connected( + self.simple_client( + "server", + ssl=True, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=1000, + **self.credentials, # type: ignore[arg-type] + ) + ) + + # Server cert and hostname are verified. + await connected( + self.simple_client( + "localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] + ) + + # Server cert and hostname are verified. + await connected( + self.simple_client( + "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000", + **self.credentials, # type: ignore[arg-type] + ) + ) + + def test_system_certs_config_error(self): + ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC) + if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr( + ctx, "load_default_certs" + ): + raise SkipTest("Can't test when system CA certificates are loadable.") + + have_certifi = ssl_support.HAVE_CERTIFI + have_wincertstore = ssl_support.HAVE_WINCERTSTORE + # Force the test regardless of environment. + ssl_support.HAVE_CERTIFI = False + ssl_support.HAVE_WINCERTSTORE = False + try: + with self.assertRaises(ConfigurationError): + self.simple_client("mongodb://localhost/?ssl=true") + finally: + ssl_support.HAVE_CERTIFI = have_certifi + ssl_support.HAVE_WINCERTSTORE = have_wincertstore + + def test_certifi_support(self): + if hasattr(ssl, "SSLContext"): + # SSLSocket doesn't provide ca_certs attribute on pythons + # with SSLContext and SSLContext provides no information + # about ca_certs. 
+ raise SkipTest("Can't test when SSLContext available.") + if not ssl_support.HAVE_CERTIFI: + raise SkipTest("Need certifi to test certifi support.") + + have_wincertstore = ssl_support.HAVE_WINCERTSTORE + # Force the test on Windows, regardless of environment. + ssl_support.HAVE_WINCERTSTORE = False + try: + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, CA_PEM) + + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where()) + finally: + ssl_support.HAVE_WINCERTSTORE = have_wincertstore + + def test_wincertstore(self): + if sys.platform != "win32": + raise SkipTest("Only valid on Windows.") + if hasattr(ssl, "SSLContext"): + # SSLSocket doesn't provide ca_certs attribute on pythons + # with SSLContext and SSLContext provides no information + # about ca_certs. + raise SkipTest("Can't test when SSLContext available.") + if not ssl_support.HAVE_WINCERTSTORE: + raise SkipTest("Need wincertstore to test wincertstore.") + + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, CA_PEM) + + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) + ssl_sock = ctx.wrap_socket(socket.socket()) + self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name) + + @async_client_context.require_auth + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_mongodb_x509_auth(self): + host, port = await async_client_context.host, await async_client_context.port + self.addAsyncCleanup(remove_all_users, async_client_context.client["$external"]) + + # Give x509 user all necessary privileges. + await async_client_context.create_user( + "$external", + MONGODB_X509_USERNAME, + roles=[ + {"role": "readWriteAnyDatabase", "db": "admin"}, + {"role": "userAdminAnyDatabase", "db": "admin"}, + ], + ) + + noauth = self.simple_client( + await async_client_context.pair, + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + with self.assertRaises(OperationFailure): + await noauth.pymongo_test.test.find_one() + + listener = EventListener() + auth = self.simple_client( + await async_client_context.pair, + authMechanism="MONGODB-X509", + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + event_listeners=[listener], + ) + + # No error + await auth.pymongo_test.test.find_one() + names = listener.started_command_names() + if async_client_context.version.at_least(4, 4, -1): + # Speculative auth skips the authenticate command. 
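An aside on MONGODB_X509_USERNAME: the username must match the client certificate's subject exactly as the server renders it. If you ever need to recover it from a PEM file, the cryptography package (an assumption here; it is not a PyMongo dependency) can print the distinguished name, though RFC 4514 reverses the RDN order, so verify the result against what the server reports:

    from cryptography import x509

    with open(CLIENT_PEM, "rb") as f:
        cert = x509.load_pem_x509_certificate(f.read())
    # Note: this prints RDNs in RFC 4514 order (CN first), the reverse of
    # the "C=US,...,CN=client" string hard-coded above.
    subject = cert.subject.rfc4514_string()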
+ self.assertEqual(names, ["find"]) + else: + self.assertEqual(names, ["authenticate", "find"]) + + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) + client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + # No error + await client.pymongo_test.test.find_one() + + uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) + client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + # No error + await client.pymongo_test.test.find_one() + # Auth should fail if username and certificate do not match + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus("not the username"), + host, + port, + ) + + bad_client = self.simple_client( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) + + with self.assertRaises(OperationFailure): + await bad_client.pymongo_test.test.find_one() + + bad_client = self.simple_client( + await async_client_context.pair, + username="not the username", + authMechanism="MONGODB-X509", + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) + + with self.assertRaises(OperationFailure): + await bad_client.pymongo_test.test.find_one() + + # Invalid certificate (using CA certificate as client certificate) + uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) + try: + await connected( + self.simple_client( + uri, + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CA_PEM, + serverSelectionTimeoutMS=1000, + ) + ) + except (ConnectionFailure, ConfigurationError): + pass + else: + self.fail("Invalid certificate accepted.") + + @async_client_context.require_tlsCertificateKeyFile + @ignore_deprecations + async def test_connect_with_ca_bundle(self): + def remove(path): + try: + os.remove(path) + except OSError: + pass + + temp_ca_bundle = os.path.join(CERT_PATH, "trusted-ca-bundle.pem") + self.addCleanup(remove, temp_ca_bundle) + # Add the CA cert file to the bundle. + cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) + async with self.simple_client( + "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle + ) as client: + self.assertTrue(await client.admin.command("ping")) + + @async_client_context.require_async + @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + async def test_pyopenssl_ignored_in_async(self): + client = AsyncMongoClient( + "mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true" + ) + await client.admin.command("ping") # command doesn't matter, just needs it to connect + await client.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_streaming_protocol.py b/test/asynchronous/test_streaming_protocol.py new file mode 100644 index 0000000000..1206e7b2fa --- /dev/null +++ b/test/asynchronous/test_streaming_protocol.py @@ -0,0 +1,228 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the database module.""" +from __future__ import annotations + +import sys +import time + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils_shared import ( + HeartbeatEventListener, + ServerEventListener, + async_wait_until, +) + +from pymongo import monitoring +from pymongo.hello import HelloCompat + +_IS_SYNC = False + + +class TestStreamingProtocol(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_failCommand_streaming(self): + listener = ServerEventListener() + hb_listener = HeartbeatEventListener() + client = await self.async_rs_or_single_client( + event_listeners=[listener, hb_listener], + heartbeatFrequencyMS=500, + appName="failingHeartbeatTest", + ) + # Force a connection. + await client.admin.command("ping") + address = await client.address + listener.reset() + + fail_hello = { + "configureFailPoint": "failCommand", + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": False, + "errorCode": 10107, + "appName": "failingHeartbeatTest", + }, + } + async with self.fail_point(fail_hello): + + def _marked_unknown(event): + return ( + event.server_address == address + and not event.new_description.is_server_type_known + ) + + def _discovered_node(event): + return ( + event.server_address == address + and not event.previous_description.is_server_type_known + and event.new_description.is_server_type_known + ) + + def marked_unknown(): + return len(listener.matching(_marked_unknown)) >= 1 + + def rediscovered(): + return len(listener.matching(_discovered_node)) >= 1 + + # Topology events are not published synchronously + await async_wait_until(marked_unknown, "mark node unknown") + await async_wait_until(rediscovered, "rediscover node") + + # Server should be selectable. + await client.admin.command("ping") + + @async_client_context.require_failCommand_appName + async def test_streaming_rtt(self): + listener = ServerEventListener() + hb_listener = HeartbeatEventListener() + # On Windows, RTT can actually be 0.0 because time.time() only has + # 1-15 millisecond resolution. We need to delay the initial hello + # to ensure that RTT is never zero. + name = "streamingRttTest" + delay_hello: dict = { + "configureFailPoint": "failCommand", + "mode": {"times": 1000}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "blockConnection": True, + "blockTimeMS": 20, + # This can be uncommented after SERVER-49220 is fixed. + # 'appName': name, + }, + } + async with self.fail_point(delay_hello): + client = await self.async_rs_or_single_client( + event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name + ) + # Force a connection. + await client.admin.command("ping") + address = await client.address + + delay_hello["data"]["blockTimeMS"] = 500 + delay_hello["data"]["appName"] = name + async with self.fail_point(delay_hello): + + def rtt_exceeds_250_ms(): + # XXX: Add a public TopologyDescription getter to MongoClient? 
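For readers wondering why 250 ms is attainable here: the driver's MovingAverage (exercised by the RTT spec tests added earlier in this patch) seeds itself with the first sample and then weights each new sample at 0.2:

    new_avg = 0.8 * old_avg + 0.2 * sample

Assuming the unblocked hello measured about 20 ms, heartbeats blocked for 500 ms yield averages of roughly 0.116 s, 0.193 s, then 0.254 s, so the wait below should succeed after a few delayed checks.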
+ topology = client._topology + sd = topology.description.server_descriptions()[address] + assert sd.round_trip_time is not None + return sd.round_trip_time > 0.250 + + await async_wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT") + + # Server should be selectable. + await client.admin.command("ping") + + def changed_event(event): + return event.server_address == address and isinstance( + event, monitoring.ServerDescriptionChangedEvent + ) + + # There should only be one event published, for the initial discovery. + events = listener.matching(changed_event) + self.assertEqual(1, len(events)) + self.assertGreater(events[0].new_description.round_trip_time, 0) + + @async_client_context.require_failCommand_appName + async def test_monitor_waits_after_server_check_error(self): + # This test implements: + # https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-discovery-and-monitoring-tests.md#monitors-sleep-at-least-minheartbeatfreqencyms-between-checks + fail_hello = { + "mode": {"times": 5}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMMinHeartbeatFrequencyTest", + }, + } + async with self.fail_point(fail_hello): + start = time.time() + client = await self.async_single_client( + appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 + ) + # Force a connection. + await client.admin.command("ping") + duration = time.time() - start + # Explanation of the expected events: + # 0ms: run configureFailPoint + # 1ms: create MongoClient + # 2ms: failed monitor handshake, 1 + # 502ms: failed monitor handshake, 2 + # 1002ms: failed monitor handshake, 3 + # 1502ms: failed monitor handshake, 4 + # 2002ms: failed monitor handshake, 5 + # 2502ms: monitor handshake succeeds + # 2503ms: run awaitable hello + # 2504ms: application handshake succeeds + # 2505ms: ping command succeeds + self.assertGreaterEqual(duration, 2) + self.assertLessEqual(duration, 3.5) + + @async_client_context.require_failCommand_appName + async def test_heartbeat_awaited_flag(self): + hb_listener = HeartbeatEventListener() + client = await self.async_single_client( + event_listeners=[hb_listener], + heartbeatFrequencyMS=500, + appName="heartbeatEventAwaitedFlag", + ) + # Force a connection. + await client.admin.command("ping") + + def hb_succeeded(event): + return isinstance(event, monitoring.ServerHeartbeatSucceededEvent) + + def hb_failed(event): + return isinstance(event, monitoring.ServerHeartbeatFailedEvent) + + fail_heartbeat = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": True, + "appName": "heartbeatEventAwaitedFlag", + }, + } + async with self.fail_point(fail_heartbeat): + await async_wait_until( + lambda: hb_listener.matching(hb_failed), "published failed event" + ) + # Reconnect. + await client.admin.command("ping") + + hb_succeeded_events = hb_listener.matching(hb_succeeded) + hb_failed_events = hb_listener.matching(hb_failed) + self.assertFalse(hb_succeeded_events[0].awaited) + self.assertTrue(hb_failed_events[0].awaited) + # Depending on thread scheduling, the failed heartbeat could occur on + # the second or third check. 
+ events = [type(e) for e in hb_listener.events[:4]] + if events == [ + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatSucceededEvent, + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatFailedEvent, + ]: + self.assertFalse(hb_succeeded_events[1].awaited) + else: + self.assertTrue(hb_succeeded_events[1].awaited) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index d11d0a9776..f151755217 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, async_wait_until, ) @@ -32,7 +32,7 @@ from bson import encode from bson.raw_bson import RawBSONDocument -from pymongo import WriteConcern +from pymongo import WriteConcern, _csot from pymongo.asynchronous import client_session from pymongo.asynchronous.client_session import TransactionOptions from pymongo.asynchronous.command_cursor import AsyncCommandCursor @@ -295,6 +295,14 @@ async def gridfs_open_upload_stream(*args, **kwargs): "new-name", ), ), + ( + bucket.rename_by_name, + ( + "new-name", + "new-name2", + ), + ), + (bucket.delete_by_name, ("new-name2",)), ] async with client.start_session() as s, await s.start_transaction(): @@ -410,15 +418,10 @@ async def asyncSetUp(self) -> None: for address in async_client_context.mongoses: self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) - async def _set_fail_point(self, client, command_args): - cmd = {"configureFailPoint": "failCommand"} - cmd.update(command_args) - await client.admin.command(cmd) - async def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: - await self._set_fail_point(client, command_args) + await self.configure_fail_point(client, command_args) @async_client_context.require_transactions async def test_callback_raises_custom_error(self): @@ -583,5 +586,29 @@ async def callback(session): self.assertFalse(s.in_transaction) +class TestOptionsInsideTransactionProse(AsyncTransactionsBase): + @async_client_context.require_transactions + @async_client_context.require_no_standalone + async def test_case_1(self): + # Write concern not inherited from collection object inside transaction + # Create a MongoClient running against a configured sharded/replica set/load balanced cluster. + client = async_client_context.client + coll = client[self.db.name].test + await coll.delete_many({}) + # Start a new session on the client. + async with client.start_session() as s: + # Start a transaction on the session. + await s.start_transaction() + # Instantiate a collection object in the driver with a default write concern of { w: 0 }. + inner_coll = coll.with_options(write_concern=WriteConcern(w=0)) + # Insert the document { n: 1 } on the instantiated collection. + result = await inner_coll.insert_one({"n": 1}, session=s) + # Commit the transaction. + await s.commit_transaction() + # End the session. + # Ensure the document was inserted and no error was thrown from the transaction. 
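(For contrast with the { w: 0 } case this test exercises: per the transactions spec, write concern can only be configured at the transaction, session, or client level, e.g.:

    async with client.start_session() as s:
        await s.start_transaction(write_concern=WriteConcern(w="majority"))
        await coll.insert_one({"n": 1}, session=s)
        await s.commit_transaction()

The assertion that follows confirms the insert issued through the w:0 collection still succeeded inside the transaction.)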
+ assert result.inserted_id is not None + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_transactions_unified.py b/test/asynchronous/test_transactions_unified.py new file mode 100644 index 0000000000..4519a0e39a --- /dev/null +++ b/test/asynchronous/test_transactions_unified.py @@ -0,0 +1,56 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Transactions unified spec tests.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +sys.path[0:0] = [""] + +from test import client_context, unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + + +@client_context.require_no_mmap +def setUpModule(): + pass + + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +# Location of JSON test specifications for transactions-convenient-api. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" + ) + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py new file mode 100644 index 0000000000..a005739e95 --- /dev/null +++ b/test/asynchronous/test_unified_format.py @@ -0,0 +1,99 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any + +sys.path[0:0] = [""] + +from test import UnitTest, unittest +from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes + +from bson import ObjectId + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "valid-pass"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + expected_failures=[ + "Client side error in command starting transaction", # PYTHON-1894 + ], + RUN_ON_SERVERLESS=False, + ) +) + + +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "valid-fail"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + bypass_test_generation_errors=True, + expected_failures=[ + ".*", # All tests expected to fail + ], + RUN_ON_SERVERLESS=False, + ) +) + + +class TestMatchEvaluatorUtil(UnitTest): + def setUp(self): + self.match_evaluator = MatchEvaluatorUtil(self) + + def test_unsetOrMatches(self): + spec: dict[str, Any] = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}} + for actual in [{}, {"y": 2}, None]: + self.match_evaluator.match_result(spec, actual) + + spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}} + for actual in [{}, {"x": {}}, {"x": {"y": 2}}]: + self.match_evaluator.match_result(spec, actual) + + spec = {"y": {"$$unsetOrMatches": {"$$exists": True}}} + self.match_evaluator.match_result(spec, {}) + self.match_evaluator.match_result(spec, {"y": 2}) + self.match_evaluator.match_result(spec, {"x": 1}) + self.match_evaluator.match_result(spec, {"y": {}}) + + def test_type(self): + self.match_evaluator.match_result( + { + "operationType": "insert", + "ns": {"db": "change-stream-tests", "coll": "test"}, + "fullDocument": {"_id": {"$$type": "objectId"}, "x": 1}, + }, + { + "operationType": "insert", + "fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1}, + "ns": {"db": "change-stream-tests", "coll": "test"}, + }, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py new file mode 100644 index 0000000000..46e62d5c14 --- /dev/null +++ b/test/asynchronous/test_versioned_api_integration.py @@ -0,0 +1,86 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from test.asynchronous.unified_format import generate_test_classes + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils_shared import OvertCommandListener + +from pymongo.server_api import ServerApi + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") + + +# Generate unified tests. 
+globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestServerApiIntegration(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + def assertServerApi(self, event): + self.assertIn("apiVersion", event.command) + self.assertEqual(event.command["apiVersion"], "1") + + def assertServerApiInAllCommands(self, events): + for event in events: + self.assertServerApi(event) + + @async_client_context.require_version_min(4, 7) + async def test_command_options(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_api=ServerApi("1"), event_listeners=[listener] + ) + coll = client.test.test + await coll.insert_many([{} for _ in range(100)]) + self.addAsyncCleanup(coll.delete_many, {}) + await coll.find(batch_size=25).to_list() + await client.admin.command("ping") + self.assertServerApiInAllCommands(listener.started_events) + + @async_client_context.require_version_min(4, 7) + @async_client_context.require_transactions + async def test_command_options_txn(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_api=ServerApi("1"), event_listeners=[listener] + ) + coll = client.test.test + await coll.insert_many([{} for _ in range(100)]) + self.addAsyncCleanup(coll.delete_many, {}) + + listener.reset() + async with client.start_session() as s, await s.start_transaction(): + await coll.insert_many([{} for _ in range(100)], session=s) + await coll.find(batch_size=25, session=s).to_list() + await client.test.command("find", "test", session=s) + self.assertServerApiInAllCommands(listener.started_events) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 52d964eb3e..9099efbf0f 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,8 @@ client_knobs, unittest, ) +from test.asynchronous.utils import async_get_pool +from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, PLACEHOLDER_MAP, @@ -48,8 +50,7 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( - async_get_pool, +from test.utils_shared import ( async_wait_until, camel_to_snake, camel_to_snake_args, @@ -58,7 +59,6 @@ snake_to_camel, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread from test.version import Version from typing import Any, Dict, List, Mapping, Optional @@ -66,7 +66,7 @@ from bson import SON, json_util from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.objectid import ObjectId -from gridfs import AsyncGridFSBucket, GridOut +from gridfs import AsyncGridFSBucket, GridOut, NoFile from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot from pymongo.asynchronous.change_stream import AsyncChangeStream from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState @@ -222,7 +222,6 @@ def __init__(self, test_class): self._listeners: Dict[str, EventListenerUtil] = {} self._session_lsids: Dict[str, Mapping[str, Any]] = {} self.test: UnifiedSpecTestMixinV1 = test_class - self._cluster_time: Mapping[str, Any] = {} def __contains__(self, item): return item in self._entities @@ -378,12 +377,14 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: opts["key_vault_client"], DEFAULT_CODEC_OPTIONS, opts.get("kms_tls_options", kms_tls_options), + opts.get("key_expiration_ms"), ) 
return elif entity_type == "thread": name = spec["id"] - thread = SpecRunnerThread(name) - thread.start() + thread = SpecRunnerTask(name) + await thread.start() + self.test.addAsyncCleanup(thread.join, 5) self[name] = thread return @@ -419,13 +420,11 @@ def get_lsid_for_session(self, session_name): # session has been closed. return self._session_lsids[session_name] - async def advance_cluster_times(self) -> None: + async def advance_cluster_times(self, cluster_time) -> None: """Manually synchronize entities when desired""" - if not self._cluster_time: - self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): - if isinstance(entity, AsyncClientSession) and self._cluster_time: - entity.advance_cluster_time(self._cluster_time) + if isinstance(entity, AsyncClientSession) and cluster_time: + entity.advance_cluster_time(cluster_time) class UnifiedSpecTestMixinV1(AsyncIntegrationTest): @@ -438,7 +437,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest): a class attribute ``TEST_SPEC``. """ - SCHEMA_VERSION = Version.from_string("1.21") + SCHEMA_VERSION = Version.from_string("1.22") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True TEST_SPEC: Any @@ -544,6 +543,14 @@ def maybe_skip_test(self, spec): self.skipTest("Implement PYTHON-1894") if "timeoutMS applied to entire download" in spec["description"]: self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + if ( + "Error returned from connection pool clear with interruptInUseConnections=true is retryable" + in spec["description"] + and not _IS_SYNC + ): + self.skipTest("PYTHON-5170 tests are flakey") + if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC: + self.skipTest("PYTHON-5174 tests are flakey") class_name = self.__class__.__name__.lower() description = spec["description"].lower() @@ -558,7 +565,11 @@ def maybe_skip_test(self, spec): self.skipTest("CSOT not implemented for watch()") if "cursors" in class_name: self.skipTest("CSOT not implemented for cursors") - if "tailable" in class_name: + if ( + "tailable" in class_name + or "tailable" in description + and "non-tailable" not in description + ): self.skipTest("CSOT not implemented for tailable cursors") if "sessions" in class_name: self.skipTest("CSOT not implemented for sessions") @@ -618,7 +629,7 @@ def process_error(self, exception, spec): # Connection errors are considered client errors. 
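The reworked advance_cluster_times(cluster_time) above pushes a known $clusterTime into each session entity instead of fetching one with an extra ping. The underlying session API can be sketched like this (assumes a reachable deployment that reports $clusterTime, i.e. a replica set or sharded cluster):

import asyncio

from pymongo import AsyncMongoClient

async def main():
    client = AsyncMongoClient()  # placeholder URI: localhost defaults
    ping = await client.admin.command("ping")
    cluster_time = ping.get("$clusterTime")
    async with client.start_session() as session:
        if cluster_time:
            # Gossip the observed cluster time into the session without a
            # server round-trip; later causally consistent reads wait for it.
            session.advance_cluster_time(cluster_time)
        print(session.cluster_time)
    await client.close()

asyncio.run(main())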
if isinstance(error, ConnectionFailure): self.assertNotIsInstance(error, NotPrimaryError) - elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)): + elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)): pass else: self.assertNotIsInstance(error, PyMongoError) @@ -711,7 +722,7 @@ async def _databaseOperation_runCommand(self, target, **kwargs): return await target.command(**kwargs) async def _databaseOperation_runCursorCommand(self, target, **kwargs): - return list(await self._databaseOperation_createCommandCursor(target, **kwargs)) + return await (await self._databaseOperation_createCommandCursor(target, **kwargs)).to_list() async def _databaseOperation_createCommandCursor(self, target, **kwargs): self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase) @@ -1008,12 +1019,8 @@ async def __set_fail_point(self, client, command_args): if not async_client_context.test_commands_enabled: self.skipTest("Test commands must be enabled") - cmd_on = SON([("configureFailPoint", "failCommand")]) - cmd_on.update(command_args) - await client.admin.command(cmd_on) - self.addAsyncCleanup( - client.admin.command, "configureFailPoint", cmd_on["configureFailPoint"], mode="off" - ) + await self.configure_fail_point(client, command_args) + self.addAsyncCleanup(self.configure_fail_point, client, command_args, off=True) async def _testOperation_failPoint(self, spec): await self.__set_fail_point( @@ -1034,7 +1041,7 @@ async def _testOperation_targetedFailPoint(self, spec): async def _testOperation_createEntities(self, spec): await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri) - await self.entity_map.advance_cluster_times() + await self.entity_map.advance_cluster_times(self._cluster_time) def _testOperation_assertSessionTransactionState(self, spec): session = self.entity_map[spec["session"]] @@ -1155,7 +1162,7 @@ def _testOperation_assertTopologyType(self, spec): self.assertIsInstance(description, TopologyDescription) self.assertEqual(description.topology_type_name, spec["topologyType"]) - def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: """Run the waitForPrimaryChange test operation.""" client = self.entity_map[spec["client"]] old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] @@ -1169,24 +1176,24 @@ def get_primary(td: TopologyDescription) -> Optional[_Address]: old_primary = get_primary(old_description) - def primary_changed() -> bool: - primary = client.primary + async def primary_changed() -> bool: + primary = await client.primary if primary is None: return False return primary != old_primary - wait_until(primary_changed, "change primary", timeout=timeout) + await async_wait_until(primary_changed, "change primary", timeout=timeout) - def _testOperation_runOnThread(self, spec): + async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) - def _testOperation_waitForThread(self, spec): + async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.stop() - thread.join(10) + await thread.stop() + await thread.join(10) if thread.exc: raise thread.exc 
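The fail point plumbing above now funnels through a shared configure_fail_point helper with cleanup registered via addAsyncCleanup. For reference, the server round-trip being wrapped looks roughly like this (a sketch; it needs a test server started with --setParameter enableTestCommands=1, and the error code chosen here is arbitrary):

import asyncio

from pymongo import AsyncMongoClient
from pymongo.errors import OperationFailure

async def main():
    client = AsyncMongoClient()
    fail_point = {
        "configureFailPoint": "failCommand",
        "mode": {"times": 1},
        "data": {"failCommands": ["ping"], "errorCode": 2},
    }
    await client.admin.command(fail_point)
    try:
        await client.admin.command("ping")  # fails exactly once
    except OperationFailure as exc:
        print("ping failed as configured:", exc.code)
    finally:
        # Always disable the fail point again, mirroring the cleanup above.
        await client.admin.command("configureFailPoint", "failCommand", mode="off")
        await client.close()

asyncio.run(main())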
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) @@ -1387,7 +1394,6 @@ async def run_scenario(self, spec, uri=None): # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. await self.kill_all_sessions() - self.addAsyncCleanup(self.kill_all_sessions) if "csot" in self.id().lower(): # Retry CSOT tests up to 2 times to deal with flakey tests. @@ -1395,7 +1401,11 @@ async def run_scenario(self, spec, uri=None): for i in range(attempts): try: return await self._run_scenario(spec, uri) - except AssertionError: + except (AssertionError, OperationFailure) as exc: + if isinstance(exc, OperationFailure) and ( + _IS_SYNC or "failpoint" not in exc._message + ): + raise if i < attempts - 1: print( f"Retrying after attempt {i+1} of {self.id()} failed with:\n" @@ -1430,11 +1440,12 @@ async def _run_scenario(self, spec, uri=None): await self.entity_map.create_entities_from_spec( self.TEST_SPEC.get("createEntities", []), uri=uri ) + self._cluster_time = None # process initialData if "initialData" in self.TEST_SPEC: await self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime") - await self.entity_map.advance_cluster_times() + self._cluster_time = self.client._topology.max_cluster_time() + await self.entity_map.advance_cluster_times(self._cluster_time) if "expectLogMessages" in spec: expect_log_messages = spec["expectLogMessages"] diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py new file mode 100644 index 0000000000..f653c575e9 --- /dev/null +++ b/test/asynchronous/utils.py @@ -0,0 +1,212 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
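The CSOT retry loop above tolerates a bounded number of flaky failures before surfacing the real error. Stripped of the PyMongo specifics, the shape of that pattern is (a generic sketch, not the test-runner API):

import asyncio

async def run_with_retries(coro_factory, attempts=2):
    for i in range(attempts):
        try:
            return await coro_factory()
        except AssertionError:
            if i == attempts - 1:
                raise  # out of retries: let the genuine failure propagate
            print(f"retrying after attempt {i + 1} failed")

async def flaky():
    # Fails on the first call, succeeds on the second.
    flaky.calls = getattr(flaky, "calls", 0) + 1
    if flaky.calls < 2:
        raise AssertionError("transient timing failure")
    return "ok"

print(asyncio.run(run_with_retries(flaky)))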
+ +"""Utilities for testing pymongo that require synchronization.""" +from __future__ import annotations + +import asyncio +import contextlib +import random +import threading # Used in the synchronized version of this file +import time +from asyncio import iscoroutinefunction + +from bson.son import SON +from pymongo import AsyncMongoClient +from pymongo.errors import ConfigurationError +from pymongo.hello import HelloCompat +from pymongo.lock import _async_create_lock +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference +from pymongo.server_selectors import any_server_selector, writable_server_selector +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration + +_IS_SYNC = False + + +async def async_get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = await client._get_topology() + server = await topology._select_server(writable_server_selector, _Op.TEST) + return server.pool + + +async def async_get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in await (await client._get_topology()).select_servers( + any_server_selector, _Op.TEST + ) + ] + + +async def async_wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. + + E.g.: + + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') + + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). + + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() + if retval: + return retval + + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) + + await asyncio.sleep(interval) + + +async def async_is_mongos(client): + res = await client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" + + +async def async_ensure_all_connected(client: AsyncMongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. + + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} + + # Run hello until we have connected to each host at least once. + async def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = await client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list + + try: + + async def predicate(): + return target_host_list == await discover() + + await async_wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) + + +async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): + """ + Unlike the standard assertRaises, this checks that a function raises a + specific class of exception, and not a subclass. 
E.g., check that + MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. + """ + try: + await fn(*args, **kwargs) + except Exception as e: + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" + else: + raise AssertionError("%s not raised" % cls) + + +async def async_set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + +async def async_joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t + else: + await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) + + +class AsyncMockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + self.server_connection_id = random.randint(0, 100) + + def close_conn(self, reason): + pass + + def __aenter__(self): + return self + + def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class AsyncMockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _async_create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + yield AsyncMockConnection() + + async def checkin(self, *args, **kwargs): + pass + + async def _reset(self, service_id=None): + async with self._lock: + self.gen.inc(service_id) + + async def ready(self): + pass + + async def reset(self, service_id=None, interrupt_connections=False): + await self._reset() + + async def reset_without_pause(self): + await self._reset() + + async def close(self): + await self._reset() + + async def update_is_writable(self, is_writable): + pass + + async def remove_stale_sockets(self, *args, **kwargs): + pass diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py new file mode 100644 index 0000000000..d6b92fadb4 --- /dev/null +++ b/test/asynchronous/utils_selection_tests.py @@ -0,0 +1,204 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
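Usage of the async_wait_until helper defined above, with the helper inlined so the snippet runs outside the test tree; it accepts either a plain callable or a coroutine function as the predicate:

import asyncio
import time
from asyncio import iscoroutinefunction

async def wait_until(predicate, success_description, timeout=10):
    start = time.time()
    interval = min(float(timeout) / 100, 0.1)
    while True:
        retval = await predicate() if iscoroutinefunction(predicate) else predicate()
        if retval:
            return retval
        if time.time() - start > timeout:
            raise AssertionError("Didn't ever %s" % success_description)
        await asyncio.sleep(interval)

async def main():
    t0 = time.time()
    # Poll a plain lambda roughly every 100ms until it turns true.
    await wait_until(lambda: time.time() - t0 > 0.25, "pass 250ms", timeout=2)
    print("condition met")

asyncio.run(main())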
+ +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys +from test.asynchronous import AsyncPyMongoTestCase +from test.asynchronous.utils import AsyncMockPool + +sys.path[0:0] = [""] + +from test import unittest +from test.pymongo_mocks import DummyMonitor +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) +from test.utils_shared import parse_read_preference + +from bson import json_util +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.common import HEARTBEAT_FREQUENCY +from pymongo.errors import AutoReconnect, ConfigurationError +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector + +_IS_SYNC = False + + +def get_topology_settings_dict(**kwargs): + settings = { + "monitor_class": DummyMonitor, + "heartbeat_frequency": HEARTBEAT_FREQUENCY, + "pool_class": AsyncMockPool, + } + settings.update(kwargs) + return settings + + +async def create_topology(scenario_def, **kwargs): + # Initialize topologies. + if "heartbeatFrequencyMS" in scenario_def: + frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0 + else: + frequency = HEARTBEAT_FREQUENCY + + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + + topology_type = get_topology_type_name(scenario_def) + if topology_type == "LoadBalanced": + kwargs.setdefault("load_balanced", True) + # Force topology description to ReplicaSet + elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]: + kwargs.setdefault("replica_set_name", "rs") + settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs) + + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + topology = Topology(TopologySettings(**settings)) + await topology.open() + + # Update topologies with server descriptions. + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Assert that descriptions match + assert ( + scenario_def["topology_description"]["type"] == topology.description.topology_type_name + ), topology.description.topology_type_name + + return topology + + +def create_test(scenario_def): + async def run_scenario(self): + _, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + top_latency = await create_topology(scenario_def) + + # "In latency window" is defined in the server selection + # spec as the subset of suitable_servers that falls within the + # allowable latency window. + top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000) + + # Create server selector. + if scenario_def.get("operation") == "write": + pref = writable_server_selector + else: + # Make first letter lowercase to match read_pref's modes. + pref_def = scenario_def["read_preference"] + if scenario_def.get("error"): + with self.assertRaises((ConfigurationError, ValueError)): + # Error can be raised when making Read Pref or selecting. 
+ pref = parse_read_preference(pref_def) + await top_latency.select_server(pref, _Op.TEST) + return + + pref = parse_read_preference(pref_def) + + # Select servers. + if not scenario_def.get("suitable_servers"): + with self.assertRaises(AutoReconnect): + await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + if not scenario_def["in_latency_window"]: + with self.assertRaises(AutoReconnect): + await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + actual_suitable_s = await top_suitable.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + actual_latency_s = await top_latency.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + + expected_suitable_servers = {} + for server in scenario_def["suitable_servers"]: + server_description = make_server_description(server, hosts) + expected_suitable_servers[server["address"]] = server_description + + actual_suitable_servers = {} + for s in actual_suitable_s: + actual_suitable_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) + for k, actual in actual_suitable_servers.items(): + expected = expected_suitable_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + expected_latency_servers = {} + for server in scenario_def["in_latency_window"]: + server_description = make_server_description(server, hosts) + expected_latency_servers[server["address"]] = server_description + + actual_latency_servers = {} + for s in actual_latency_s: + actual_latency_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) + for k, actual in actual_latency_servers.items(): + expected = expected_latency_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + return run_scenario + + +def create_selection_tests(test_dir): + class TestAllScenarios(AsyncPyMongoTestCase): + pass + + for dirpath, _, filenames in os.walk(test_dir): + dirname = os.path.split(dirpath) + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. 
+ new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + return TestAllScenarios diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index b79e5258b5..c83636a734 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -18,12 +18,13 @@ import asyncio import functools import os -import threading +import time import unittest from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs -from test.utils import ( +from test.asynchronous.helpers import ConcurrentRunner +from test.utils_shared import ( CMAPListener, CompareType, EventListener, @@ -47,6 +48,7 @@ from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,38 +57,36 @@ _IS_SYNC = False -class SpecRunnerThread(threading.Thread): +class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): - super().__init__() - self.name = name + super().__init__(name=name) self.exc = None self.daemon = True - self.cond = threading.Condition() + self.cond = _async_create_condition(_async_create_lock()) self.ops = [] - self.stopped = False - def schedule(self, work): + async def schedule(self, work): self.ops.append(work) - with self.cond: + async with self.cond: self.cond.notify() - def stop(self): + async def stop(self): self.stopped = True - with self.cond: + async with self.cond: self.cond.notify() - def run(self): + async def run(self): while not self.stopped or self.ops: if not self.ops: - with self.cond: - self.cond.wait(10) + async with self.cond: + await _async_cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) - work() + await work() except Exception as exc: self.exc = exc - self.stop() + await self.stop() class AsyncSpecTestCreator: @@ -230,7 +230,7 @@ async def _create_tests(self): str(test_def["description"].replace(" ", "_").replace(".", "_")), ) - new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) @@ -265,15 +265,10 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: self.knobs.disable() - async def _set_fail_point(self, client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - await client.admin.command(cmd) - async def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: - await self._set_fail_point(client, command_args) + await self.configure_fail_point(client, command_args) async def targeted_fail_point(self, session, fail_point): """Run the targetedFailPoint test operation. 
@@ -282,7 +277,7 @@ async def targeted_fail_point(self, session, fail_point): """ clients = {c.address: c for c in self.mongos_clients} client = clients[session._pinned_address] - await self._set_fail_point(client, fail_point) + await self.configure_fail_point(client, fail_point) self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) def assert_session_pinned(self, session): @@ -320,6 +315,10 @@ async def assert_index_not_exists(self, database, collection, index): coll = self.client[database][collection] self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) + async def wait(self, ms): + """Run the "wait" test operation.""" + await asyncio.sleep(ms / 1000.0) + def assertErrorLabelsContain(self, exc, expected_labels): labels = [l for l in expected_labels if exc.has_error_label(l)] self.assertEqual(labels, expected_labels) diff --git a/test/atlas/test_connection.py b/test/atlas/test_connection.py index 4dcbba6d11..a3e8b0b1d5 100644 --- a/test/atlas/test_connection.py +++ b/test/atlas/test_connection.py @@ -26,9 +26,9 @@ sys.path[0:0] = [""] import pymongo -from pymongo.ssl_support import HAS_SNI +from pymongo.ssl_support import _has_sni -pytestmark = pytest.mark.atlas +pytestmark = pytest.mark.atlas_connect URIS = { @@ -57,7 +57,7 @@ def connect(self, uri): # No auth error client.test.test.count_documents({}) - @unittest.skipUnless(HAS_SNI, "Free tier requires SNI support") + @unittest.skipUnless(_has_sni(True), "Free tier requires SNI support") def test_free_tier(self): self.connect(URIS["ATLAS_FREE"]) @@ -80,7 +80,7 @@ def connect_srv(self, uri): self.connect(uri) self.assertIn("mongodb+srv://", uri) - @unittest.skipUnless(HAS_SNI, "Free tier requires SNI support") + @unittest.skipUnless(_has_sni(True), "Free tier requires SNI support") def test_srv_free_tier(self): self.connect_srv(URIS["ATLAS_SRV_FREE"]) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index a7660f2f67..9738694d85 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -32,7 +32,7 @@ from pymongo import MongoClient from pymongo.errors import OperationFailure -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri pytestmark = pytest.mark.auth_aws diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7a78f3d2f6..7dbf817cce 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.unified_format import generate_test_classes -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from bson import SON from pymongo import MongoClient @@ -49,7 +49,7 @@ OIDCCallbackResult, _get_authenticator, ) -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" @@ -70,6 +70,11 @@ def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] cls.uri_multiple = os.environ.get("MONGODB_URI_MULTI") cls.uri_admin = os.environ["MONGODB_URI"] + if ENVIRON == "test": + if not TOKEN_DIR: + raise ValueError("Please set OIDC_TOKEN_DIR") + if not TOKEN_FILE: + raise ValueError("Please set OIDC_TOKEN_FILE") def setUp(self): self.request_called = 0 @@ -237,9 +242,9 @@ def test_1_6_allowed_hosts_blocked(self): authmechanismproperties=props, connect=False, ) - # Assert that a find operation fails with a client-side error. 
- with self.assertRaises(ConfigurationError): - client.test.test.find_one() + # Assert that a find operation fails with a client-side error. + with self.assertRaises(ConfigurationError): + client.test.test.find_one() # Close the client. client.close() diff --git a/test/bson_binary_vector/float32.json b/test/bson_binary_vector/float32.json index bbbe00b758..72dafce10f 100644 --- a/test/bson_binary_vector/float32.json +++ b/test/bson_binary_vector/float32.json @@ -11,6 +11,15 @@ "padding": 0, "canonical_bson": "1C00000005766563746F72000A0000000927000000FE420000E04000" }, + { + "description": "Vector with decimals and negative value FLOAT32", + "valid": true, + "vector": [127.7, -7.7], + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "padding": 0, + "canonical_bson": "1C00000005766563746F72000A0000000927006666FF426666F6C000" + }, { "description": "Empty Vector FLOAT32", "valid": true, @@ -23,7 +32,7 @@ { "description": "Infinity Vector FLOAT32", "valid": true, - "vector": ["-inf", 0.0, "inf"], + "vector": [{"$numberDouble": "-Infinity"}, 0.0, {"$numberDouble": "Infinity"} ], "dtype_hex": "0x27", "dtype_alias": "FLOAT32", "padding": 0, @@ -35,8 +44,22 @@ "vector": [127.0, 7.0], "dtype_hex": "0x27", "dtype_alias": "FLOAT32", - "padding": 3 + "padding": 3, + "canonical_bson": "1C00000005766563746F72000A0000000927030000FE420000E04000" + }, + { + "description": "Insufficient vector data with 3 bytes FLOAT32", + "valid": false, + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "canonical_bson": "1700000005766563746F7200050000000927002A2A2A00" + }, + { + "description": "Insufficient vector data with 5 bytes FLOAT32", + "valid": false, + "dtype_hex": "0x27", + "dtype_alias": "FLOAT32", + "canonical_bson": "1900000005766563746F7200070000000927002A2A2A2A2A00" } ] } - diff --git a/test/bson_binary_vector/int8.json b/test/bson_binary_vector/int8.json index 7529721e5e..29524fb617 100644 --- a/test/bson_binary_vector/int8.json +++ b/test/bson_binary_vector/int8.json @@ -42,7 +42,8 @@ "vector": [127, 7], "dtype_hex": "0x03", "dtype_alias": "INT8", - "padding": 3 + "padding": 3, + "canonical_bson": "1600000005766563746F7200040000000903037F0700" }, { "description": "INT8 with float inputs", @@ -54,4 +55,3 @@ } ] } - diff --git a/test/bson_binary_vector/packed_bit.json b/test/bson_binary_vector/packed_bit.json index a41cd593f5..a220e7e318 100644 --- a/test/bson_binary_vector/packed_bit.json +++ b/test/bson_binary_vector/packed_bit.json @@ -2,6 +2,15 @@ "description": "Tests of Binary subtype 9, Vectors, with dtype PACKED_BIT", "test_key": "vector", "tests": [ + { + "description": "Padding specified with no vector data PACKED_BIT", + "valid": false, + "vector": [], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 1, + "canonical_bson": "1400000005766563746F72000200000009100100" + }, { "description": "Simple Vector PACKED_BIT", "valid": true, @@ -44,7 +53,31 @@ "dtype_hex": "0x10", "dtype_alias": "PACKED_BIT", "padding": 0 + }, + { + "description": "Vector with float values PACKED_BIT", + "valid": false, + "vector": [127.5], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 0 + }, + { + "description": "Exceeding maximum padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": 8, + "canonical_bson": "1500000005766563746F7200030000000910080100" + }, + { + "description": "Negative padding PACKED_BIT", + "valid": false, + "vector": [1], + "dtype_hex": "0x10", + "dtype_alias": "PACKED_BIT", + "padding": -1 } ] 
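A note on how the canonical_bson fixtures in these vector files are built: the subtype-9 binary payload is one dtype byte, one padding byte (meaningful only for PACKED_BIT), then the packed elements. The FLOAT32 case from above can be reproduced with struct (a pure-Python sketch; recent PyMongo also exposes this encoding through bson's Binary vector helpers):

import struct

dtype, padding = 0x27, 0  # FLOAT32
values = [127.0, 7.0]
payload = bytes([dtype, padding]) + struct.pack("<%df" % len(values), *values)
# 27000000FE420000E040: the bytes embedded in the fixture's canonical_bson
print(payload.hex().upper())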
} - diff --git a/test/change_streams/unified/change-streams-clusterTime.json b/test/change_streams/unified/change-streams-clusterTime.json index 55b4ae3fbc..2b09e548f1 100644 --- a/test/change_streams/unified/change-streams-clusterTime.json +++ b/test/change_streams/unified/change-streams-clusterTime.json @@ -28,7 +28,6 @@ "minServerVersion": "4.0.0", "topologies": [ "replicaset", - "sharded-replicaset", "load-balanced", "sharded" ], diff --git a/test/change_streams/unified/change-streams-disambiguatedPaths.json b/test/change_streams/unified/change-streams-disambiguatedPaths.json index 91d8e66da2..a8667b5436 100644 --- a/test/change_streams/unified/change-streams-disambiguatedPaths.json +++ b/test/change_streams/unified/change-streams-disambiguatedPaths.json @@ -28,7 +28,6 @@ "minServerVersion": "6.1.0", "topologies": [ "replicaset", - "sharded-replicaset", "load-balanced", "sharded" ], @@ -43,70 +42,6 @@ } ], "tests": [ - { - "description": "disambiguatedPaths is not present when showExpandedEvents is false/unset", - "operations": [ - { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "a": { - "1": 1 - } - } - } - }, - { - "name": "createChangeStream", - "object": "collection0", - "arguments": { - "pipeline": [] - }, - "saveResultAsEntity": "changeStream0" - }, - { - "name": "updateOne", - "object": "collection0", - "arguments": { - "filter": { - "_id": 1 - }, - "update": { - "$set": { - "a.1": 2 - } - } - } - }, - { - "name": "iterateUntilDocumentOrError", - "object": "changeStream0", - "expectResult": { - "operationType": "update", - "ns": { - "db": "database0", - "coll": "collection0" - }, - "updateDescription": { - "updatedFields": { - "$$exists": true - }, - "removedFields": { - "$$exists": true - }, - "truncatedArrays": { - "$$exists": true - }, - "disambiguatedPaths": { - "$$exists": false - } - } - } - } - ] - }, { "description": "disambiguatedPaths is present on updateDescription when an ambiguous path is present", "operations": [ diff --git a/test/change_streams/unified/change-streams-errors.json b/test/change_streams/unified/change-streams-errors.json index 04fe8f04f3..65e99e541e 100644 --- a/test/change_streams/unified/change-streams-errors.json +++ b/test/change_streams/unified/change-streams-errors.json @@ -145,7 +145,7 @@ "minServerVersion": "4.1.11", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ] } @@ -190,7 +190,7 @@ "minServerVersion": "4.2", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ] } diff --git a/test/change_streams/unified/change-streams-nsType.json b/test/change_streams/unified/change-streams-nsType.json new file mode 100644 index 0000000000..1861c9a5e0 --- /dev/null +++ b/test/change_streams/unified/change-streams-nsType.json @@ -0,0 +1,145 @@ +{ + "description": "change-streams-nsType", + "schemaVersion": "1.7", + "runOnRequirements": [ + { + "minServerVersion": "8.1.0", + "topologies": [ + "replicaset", + "sharded" + ], + "serverless": "forbid" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "database0" + } + } + ], + "tests": [ + { + "description": "nsType is present when creating collections", + "operations": [ + { + "name": "dropCollection", + "object": "database0", + "arguments": { + "collection": "foo" + } + }, + { + "name": "createChangeStream", + "object": "database0", + "arguments": { + 
"pipeline": [], + "showExpandedEvents": true + }, + "saveResultAsEntity": "changeStream0" + }, + { + "name": "createCollection", + "object": "database0", + "arguments": { + "collection": "foo" + } + }, + { + "name": "iterateUntilDocumentOrError", + "object": "changeStream0", + "expectResult": { + "operationType": "create", + "nsType": "collection" + } + } + ] + }, + { + "description": "nsType is present when creating timeseries", + "operations": [ + { + "name": "dropCollection", + "object": "database0", + "arguments": { + "collection": "foo" + } + }, + { + "name": "createChangeStream", + "object": "database0", + "arguments": { + "pipeline": [], + "showExpandedEvents": true + }, + "saveResultAsEntity": "changeStream0" + }, + { + "name": "createCollection", + "object": "database0", + "arguments": { + "collection": "foo", + "timeseries": { + "timeField": "time", + "metaField": "meta", + "granularity": "minutes" + } + } + }, + { + "name": "iterateUntilDocumentOrError", + "object": "changeStream0", + "expectResult": { + "operationType": "create", + "nsType": "timeseries" + } + } + ] + }, + { + "description": "nsType is present when creating views", + "operations": [ + { + "name": "dropCollection", + "object": "database0", + "arguments": { + "collection": "foo" + } + }, + { + "name": "createChangeStream", + "object": "database0", + "arguments": { + "pipeline": [], + "showExpandedEvents": true + }, + "saveResultAsEntity": "changeStream0" + }, + { + "name": "createCollection", + "object": "database0", + "arguments": { + "collection": "foo", + "viewOn": "testName" + } + }, + { + "name": "iterateUntilDocumentOrError", + "object": "changeStream0", + "expectResult": { + "operationType": "create", + "nsType": "view" + } + } + ] + } + ] +} diff --git a/test/change_streams/unified/change-streams-pre_and_post_images.json b/test/change_streams/unified/change-streams-pre_and_post_images.json index 8beefb2bc8..e62fc03459 100644 --- a/test/change_streams/unified/change-streams-pre_and_post_images.json +++ b/test/change_streams/unified/change-streams-pre_and_post_images.json @@ -6,7 +6,7 @@ "minServerVersion": "6.0.0", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ], "serverless": "forbid" diff --git a/test/change_streams/unified/change-streams-resume-allowlist.json b/test/change_streams/unified/change-streams-resume-allowlist.json index b4953ec736..1ec72b432b 100644 --- a/test/change_streams/unified/change-streams-resume-allowlist.json +++ b/test/change_streams/unified/change-streams-resume-allowlist.json @@ -6,7 +6,7 @@ "minServerVersion": "3.6", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ], "serverless": "forbid" diff --git a/test/change_streams/unified/change-streams-resume-errorLabels.json b/test/change_streams/unified/change-streams-resume-errorLabels.json index f5f4505a9f..7fd70108f0 100644 --- a/test/change_streams/unified/change-streams-resume-errorLabels.json +++ b/test/change_streams/unified/change-streams-resume-errorLabels.json @@ -6,7 +6,7 @@ "minServerVersion": "4.3.1", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ], "serverless": "forbid" diff --git a/test/change_streams/unified/change-streams-showExpandedEvents.json b/test/change_streams/unified/change-streams-showExpandedEvents.json index 3eed2f534a..b9594e0c1e 100644 --- a/test/change_streams/unified/change-streams-showExpandedEvents.json +++ b/test/change_streams/unified/change-streams-showExpandedEvents.json @@ -6,9 +6,9 @@ 
"minServerVersion": "6.0.0", "topologies": [ "replicaset", - "sharded-replicaset", "sharded" - ] + ], + "serverless": "forbid" } ], "createEntities": [ @@ -462,7 +462,6 @@ "runOnRequirements": [ { "topologies": [ - "sharded-replicaset", "sharded" ] } diff --git a/test/change_streams/unified/change-streams.json b/test/change_streams/unified/change-streams.json index c8b60ed4e2..a155d85b6e 100644 --- a/test/change_streams/unified/change-streams.json +++ b/test/change_streams/unified/change-streams.json @@ -181,7 +181,12 @@ "field": "array", "newSize": 2 } - ] + ], + "disambiguatedPaths": { + "$$unsetOrMatches": { + "$$exists": true + } + } } } } @@ -1408,6 +1413,11 @@ "$$unsetOrMatches": { "$$exists": true } + }, + "disambiguatedPaths": { + "$$unsetOrMatches": { + "$$exists": true + } } } } diff --git a/test/client-side-encryption/etc/data/lookup/key-doc.json b/test/client-side-encryption/etc/data/lookup/key-doc.json new file mode 100644 index 0000000000..566b56c354 --- /dev/null +++ b/test/client-side-encryption/etc/data/lookup/key-doc.json @@ -0,0 +1,30 @@ +{ + "_id": { + "$binary": { + "base64": "EjRWeBI0mHYSNBI0VniQEg==", + "subType": "04" + } + }, + "keyMaterial": { + "$binary": { + "base64": "sHe0kz57YW7v8g9VP9sf/+K1ex4JqKc5rf/URX3n3p8XdZ6+15uXPaSayC6adWbNxkFskuMCOifDoTT+rkqMtFkDclOy884RuGGtUysq3X7zkAWYTKi8QAfKkajvVbZl2y23UqgVasdQu3OVBQCrH/xY00nNAs/52e958nVjBuzQkSb1T8pKJAyjZsHJ60+FtnfafDZSTAIBJYn7UWBCwQ==", + "subType": "00" + } + }, + "creationDate": { + "$date": { + "$numberLong": "1648914851981" + } + }, + "updateDate": { + "$date": { + "$numberLong": "1648914851981" + } + }, + "status": { + "$numberInt": "0" + }, + "masterKey": { + "provider": "local" + } +} diff --git a/test/client-side-encryption/etc/data/lookup/schema-csfle.json b/test/client-side-encryption/etc/data/lookup/schema-csfle.json new file mode 100644 index 0000000000..29ac9ad5da --- /dev/null +++ b/test/client-side-encryption/etc/data/lookup/schema-csfle.json @@ -0,0 +1,19 @@ +{ + "properties": { + "csfle": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "EjRWeBI0mHYSNBI0VniQEg==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + } + } + }, + "bsonType": "object" +} diff --git a/test/client-side-encryption/etc/data/lookup/schema-csfle2.json b/test/client-side-encryption/etc/data/lookup/schema-csfle2.json new file mode 100644 index 0000000000..3f1c02781c --- /dev/null +++ b/test/client-side-encryption/etc/data/lookup/schema-csfle2.json @@ -0,0 +1,19 @@ +{ + "properties": { + "csfle2": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "EjRWeBI0mHYSNBI0VniQEg==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + } + } + }, + "bsonType": "object" +} diff --git a/test/client-side-encryption/etc/data/lookup/schema-qe.json b/test/client-side-encryption/etc/data/lookup/schema-qe.json new file mode 100644 index 0000000000..9428ea1b45 --- /dev/null +++ b/test/client-side-encryption/etc/data/lookup/schema-qe.json @@ -0,0 +1,20 @@ +{ + "escCollection": "enxcol_.qe.esc", + "ecocCollection": "enxcol_.qe.ecoc", + "fields": [ + { + "keyId": { + "$binary": { + "base64": "EjRWeBI0mHYSNBI0VniQEg==", + "subType": "04" + } + }, + "path": "qe", + "bsonType": "string", + "queries": { + "queryType": "equality", + "contention": 0 + } + } + ] +} diff --git a/test/client-side-encryption/etc/data/lookup/schema-qe2.json b/test/client-side-encryption/etc/data/lookup/schema-qe2.json new 
file mode 100644 index 0000000000..77d5bd37cb --- /dev/null +++ b/test/client-side-encryption/etc/data/lookup/schema-qe2.json @@ -0,0 +1,20 @@ +{ + "escCollection": "enxcol_.qe2.esc", + "ecocCollection": "enxcol_.qe2.ecoc", + "fields": [ + { + "keyId": { + "$binary": { + "base64": "EjRWeBI0mHYSNBI0VniQEg==", + "subType": "04" + } + }, + "path": "qe2", + "bsonType": "string", + "queries": { + "queryType": "equality", + "contention": 0 + } + } + ] +} diff --git a/test/client-side-encryption/spec/legacy/fle2v2-Rangev2-Compact.json b/test/client-side-encryption/spec/legacy/fle2v2-Rangev2-Compact.json index bba9f25535..59241927ca 100644 --- a/test/client-side-encryption/spec/legacy/fle2v2-Rangev2-Compact.json +++ b/test/client-side-encryption/spec/legacy/fle2v2-Rangev2-Compact.json @@ -6,8 +6,7 @@ "replicaset", "sharded", "load-balanced" - ], - "serverless": "forbid" + ] } ], "database_name": "default", diff --git a/test/client-side-encryption/spec/legacy/keyCache.json b/test/client-side-encryption/spec/legacy/keyCache.json new file mode 100644 index 0000000000..912ce80020 --- /dev/null +++ b/test/client-side-encryption/spec/legacy/keyCache.json @@ -0,0 +1,270 @@ +{ + "runOn": [ + { + "minServerVersion": "4.1.10" + } + ], + "database_name": "default", + "collection_name": "default", + "data": [], + "json_schema": { + "properties": { + "encrypted_w_altname": { + "encrypt": { + "keyId": "/altname", + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random" + } + }, + "encrypted_string": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + } + }, + "random": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random" + } + }, + "encrypted_string_equivalent": { + "encrypt": { + "keyId": [ + { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ], + "bsonType": "string", + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + } + } + }, + "bsonType": "object" + }, + "key_vault_data": [ + { + "status": 1, + "_id": { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + }, + "masterKey": { + "provider": "aws", + "key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0", + "region": "us-east-1" + }, + "updateDate": { + "$date": { + "$numberLong": "1552949630483" + } + }, + "keyMaterial": { + "$binary": { + "base64": "AQICAHhQNmWG2CzOm1dq3kWLM+iDUZhEqnhJwH9wZVpuZ94A8gEqnsxXlR51T5EbEVezUqqKAAAAwjCBvwYJKoZIhvcNAQcGoIGxMIGuAgEAMIGoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDHa4jo6yp0Z18KgbUgIBEIB74sKxWtV8/YHje5lv5THTl0HIbhSwM6EqRlmBiFFatmEWaeMk4tO4xBX65eq670I5TWPSLMzpp8ncGHMmvHqRajNBnmFtbYxN3E3/WjxmdbOOe+OXpnGJPcGsftc7cB2shRfA4lICPnE26+oVNXT6p0Lo20nY5XC7jyCO", + "subType": "00" + } + }, + "creationDate": { + "$date": { + "$numberLong": "1552949630483" + } + }, + "keyAltNames": [ + "altname", + "another_altname" + ] + } + ], + "tests": [ + { + "description": "Insert with deterministic encryption, then find it", + "clientOptions": { + "autoEncryptOpts": { + "kmsProviders": { + "aws": {} + }, + "keyExpirationMS": 1 + } + }, + "operations": [ + { + "name": "insertOne", + "arguments": { + "document": { + "_id": 1, + "encrypted_string": "string0" + } + } + }, + { + "name": "wait", + "object": "testRunner", + "arguments": { + 
"ms": 50 + } + }, + { + "name": "find", + "arguments": { + "filter": { + "_id": 1 + } + }, + "result": [ + { + "_id": 1, + "encrypted_string": "string0" + } + ] + } + ], + "expectations": [ + { + "command_started_event": { + "command": { + "listCollections": 1, + "filter": { + "name": "default" + } + }, + "command_name": "listCollections" + } + }, + { + "command_started_event": { + "command": { + "find": "datakeys", + "filter": { + "$or": [ + { + "_id": { + "$in": [ + { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ] + } + }, + { + "keyAltNames": { + "$in": [] + } + } + ] + }, + "$db": "keyvault", + "readConcern": { + "level": "majority" + } + }, + "command_name": "find" + } + }, + { + "command_started_event": { + "command": { + "insert": "default", + "documents": [ + { + "_id": 1, + "encrypted_string": { + "$binary": { + "base64": "AQAAAAAAAAAAAAAAAAAAAAACwj+3zkv2VM+aTfk60RqhXq6a/77WlLwu/BxXFkL7EppGsju/m8f0x5kBDD3EZTtGALGXlym5jnpZAoSIkswHoA==", + "subType": "06" + } + } + } + ], + "ordered": true + }, + "command_name": "insert" + } + }, + { + "command_started_event": { + "command": { + "find": "default", + "filter": { + "_id": 1 + } + }, + "command_name": "find" + } + }, + { + "command_started_event": { + "command": { + "find": "datakeys", + "filter": { + "$or": [ + { + "_id": { + "$in": [ + { + "$binary": { + "base64": "AAAAAAAAAAAAAAAAAAAAAA==", + "subType": "04" + } + } + ] + } + }, + { + "keyAltNames": { + "$in": [] + } + } + ] + }, + "$db": "keyvault", + "readConcern": { + "level": "majority" + } + }, + "command_name": "find" + } + } + ], + "outcome": { + "collection": { + "data": [ + { + "_id": 1, + "encrypted_string": { + "$binary": { + "base64": "AQAAAAAAAAAAAAAAAAAAAAACwj+3zkv2VM+aTfk60RqhXq6a/77WlLwu/BxXFkL7EppGsju/m8f0x5kBDD3EZTtGALGXlym5jnpZAoSIkswHoA==", + "subType": "06" + } + } + } + ] + } + } + } + ] +} diff --git a/test/client-side-encryption/spec/legacy/timeoutMS.json b/test/client-side-encryption/spec/legacy/timeoutMS.json index 8411306224..b667767cfc 100644 --- a/test/client-side-encryption/spec/legacy/timeoutMS.json +++ b/test/client-side-encryption/spec/legacy/timeoutMS.json @@ -110,7 +110,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 600 + "blockTimeMS": 60 } }, "clientOptions": { @@ -119,7 +119,7 @@ "aws": {} } }, - "timeoutMS": 500 + "timeoutMS": 50 }, "operations": [ { diff --git a/test/client-side-encryption/spec/unified/keyCache.json b/test/client-side-encryption/spec/unified/keyCache.json new file mode 100644 index 0000000000..a39701e286 --- /dev/null +++ b/test/client-side-encryption/spec/unified/keyCache.json @@ -0,0 +1,198 @@ +{ + "description": "keyCache-explicit", + "schemaVersion": "1.22", + "runOnRequirements": [ + { + "csfle": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "clientEncryption": { + "id": "clientEncryption0", + "clientEncryptionOpts": { + "keyVaultClient": "client0", + "keyVaultNamespace": "keyvault.datakeys", + "kmsProviders": { + "local": { + "key": "OCTP9uKPPmvuqpHlqq83gPk4U6rUPxKVRRyVtrjFmVjdoa4Xzm1SzUbr7aIhNI42czkUBmrCtZKF31eaaJnxEBkqf0RFukA9Mo3NEHQWgAQ2cn9duOcRbaFUQo2z0/rB" + } + }, + "keyExpirationMS": 1 + } + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "keyvault" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "datakeys" + } + } + ], + "initialData": [ + { + "databaseName": "keyvault", + 
"collectionName": "datakeys", + "documents": [ + { + "_id": { + "$binary": { + "base64": "a+YWzdygTAG62/cNUkqZiQ==", + "subType": "04" + } + }, + "keyAltNames": [], + "keyMaterial": { + "$binary": { + "base64": "iocBkhO3YBokiJ+FtxDTS71/qKXQ7tSWhWbcnFTXBcMjarsepvALeJ5li+SdUd9ePuatjidxAdMo7vh1V2ZESLMkQWdpPJ9PaJjA67gKQKbbbB4Ik5F2uKjULvrMBnFNVRMup4JNUwWFQJpqbfMveXnUVcD06+pUpAkml/f+DSXrV3e5rxciiNVtz03dAG8wJrsKsFXWj6vTjFhsfknyBA==", + "subType": "00" + } + }, + "creationDate": { + "$date": { + "$numberLong": "1552949630483" + } + }, + "updateDate": { + "$date": { + "$numberLong": "1552949630483" + } + }, + "status": { + "$numberInt": "0" + }, + "masterKey": { + "provider": "local" + } + } + ] + } + ], + "tests": [ + { + "description": "decrypt, wait, and decrypt again", + "operations": [ + { + "name": "decrypt", + "object": "clientEncryption0", + "arguments": { + "value": { + "$binary": { + "base64": "AWvmFs3coEwButv3DVJKmYkCJ6lUzRX9R28WNlw5uyndb+8gurA+p8q14s7GZ04K2ZvghieRlAr5UwZbow3PMq27u5EIhDDczwBFcbdP1amllw==", + "subType": "06" + } + } + }, + "expectResult": "foobar" + }, + { + "name": "wait", + "object": "testRunner", + "arguments": { + "ms": 50 + } + }, + { + "name": "decrypt", + "object": "clientEncryption0", + "arguments": { + "value": { + "$binary": { + "base64": "AWvmFs3coEwButv3DVJKmYkCJ6lUzRX9R28WNlw5uyndb+8gurA+p8q14s7GZ04K2ZvghieRlAr5UwZbow3PMq27u5EIhDDczwBFcbdP1amllw==", + "subType": "06" + } + } + }, + "expectResult": "foobar" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "datakeys", + "filter": { + "$or": [ + { + "_id": { + "$in": [ + { + "$binary": { + "base64": "a+YWzdygTAG62/cNUkqZiQ==", + "subType": "04" + } + } + ] + } + }, + { + "keyAltNames": { + "$in": [] + } + } + ] + }, + "$db": "keyvault", + "readConcern": { + "level": "majority" + } + } + } + }, + { + "commandStartedEvent": { + "command": { + "find": "datakeys", + "filter": { + "$or": [ + { + "_id": { + "$in": [ + { + "$binary": { + "base64": "a+YWzdygTAG62/cNUkqZiQ==", + "subType": "04" + } + } + ] + } + }, + { + "keyAltNames": { + "$in": [] + } + } + ] + }, + "$db": "keyvault", + "readConcern": { + "level": "majority" + } + } + } + } + ] + } + ] + } + ] +} diff --git a/test/command_monitoring/unacknowledged-client-bulkWrite.json b/test/command_monitoring/unacknowledged-client-bulkWrite.json index 61bb00726c..14740cea34 100644 --- a/test/command_monitoring/unacknowledged-client-bulkWrite.json +++ b/test/command_monitoring/unacknowledged-client-bulkWrite.json @@ -95,29 +95,34 @@ "ordered": false }, "expectResult": { - "insertedCount": { - "$$unsetOrMatches": 0 - }, - "upsertedCount": { - "$$unsetOrMatches": 0 - }, - "matchedCount": { - "$$unsetOrMatches": 0 - }, - "modifiedCount": { - "$$unsetOrMatches": 0 - }, - "deletedCount": { - "$$unsetOrMatches": 0 - }, - "insertResults": { - "$$unsetOrMatches": {} - }, - "updateResults": { - "$$unsetOrMatches": {} - }, - "deleteResults": { - "$$unsetOrMatches": {} + "$$unsetOrMatches": { + "acknowledged": { + "$$unsetOrMatches": false + }, + "insertedCount": { + "$$unsetOrMatches": 0 + }, + "upsertedCount": { + "$$unsetOrMatches": 0 + }, + "matchedCount": { + "$$unsetOrMatches": 0 + }, + "modifiedCount": { + "$$unsetOrMatches": 0 + }, + "deletedCount": { + "$$unsetOrMatches": 0 + }, + "insertResults": { + "$$unsetOrMatches": {} + }, + "updateResults": { + "$$unsetOrMatches": {} + }, + "deleteResults": { + "$$unsetOrMatches": {} + } } } }, diff --git a/test/connection_string/test/valid-options.json 
b/test/connection_string/test/valid-options.json index 6c86172d08..e094bcf606 100644 --- a/test/connection_string/test/valid-options.json +++ b/test/connection_string/test/valid-options.json @@ -40,7 +40,7 @@ }, { "description": "Colon in a key value pair", - "uri": "mongodb://example.com/?authMechanism=MONGODB-OIDC&authMechanismProperties=TOKEN_RESOURCE:mongodb://test-cluster", + "uri": "mongodb://example.com/?authMechanism=MONGODB-OIDC&authMechanismProperties=TOKEN_RESOURCE:mongodb://test-cluster,ENVIRONMENT:azure", "valid": true, "warning": false, "hosts": [ @@ -53,9 +53,10 @@ "auth": null, "options": { "authmechanismProperties": { - "TOKEN_RESOURCE": "mongodb://test-cluster" + "TOKEN_RESOURCE": "mongodb://test-cluster", + "ENVIRONMENT": "azure" } } } ] -} +} \ No newline at end of file diff --git a/test/connection_string/test/valid-warnings.json b/test/connection_string/test/valid-warnings.json index daf814a75f..c46a8311c5 100644 --- a/test/connection_string/test/valid-warnings.json +++ b/test/connection_string/test/valid-warnings.json @@ -96,7 +96,7 @@ }, { "description": "Comma in a key value pair causes a warning", - "uri": "mongodb://localhost?authMechanism=MONGODB-OIDC&authMechanismProperties=TOKEN_RESOURCE:mongodb://host1%2Chost2", + "uri": "mongodb://localhost?authMechanism=MONGODB-OIDC&authMechanismProperties=TOKEN_RESOURCE:mongodb://host1%2Chost2,ENVIRONMENT:azure", "valid": true, "warning": true, "hosts": [ @@ -112,4 +112,4 @@ } } ] -} +} \ No newline at end of file diff --git a/test/crud/unified/bulkWrite-updateMany-pipeline.json b/test/crud/unified/bulkWrite-updateMany-pipeline.json new file mode 100644 index 0000000000..e938ea7535 --- /dev/null +++ b/test/crud/unified/bulkWrite-updateMany-pipeline.json @@ -0,0 +1,148 @@ +{ + "description": "bulkWrite-updateMany-pipeline", + "schemaVersion": "1.0", + "runOnRequirements": [ + { + "minServerVersion": "4.1.11" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "y": 1, + "t": { + "u": { + "v": 1 + } + } + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ], + "tests": [ + { + "description": "UpdateMany in bulk write using pipelines", + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "updateMany": { + "filter": {}, + "update": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + } + } + ] + }, + "expectResult": { + "matchedCount": 2, + "modifiedCount": 2, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "test", + "updates": [ + { + "q": {}, + "u": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ], + "multi": true, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + }, + "commandName": "update", + "databaseName": "crud-tests" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "foo": 1 + }, + { + "_id": 2, + "x": 2, + "foo": 1 + } + ] + } + ] + } + ] +} diff --git 
a/test/crud/unified/bulkWrite-updateOne-pipeline.json b/test/crud/unified/bulkWrite-updateOne-pipeline.json new file mode 100644 index 0000000000..769bd106f8 --- /dev/null +++ b/test/crud/unified/bulkWrite-updateOne-pipeline.json @@ -0,0 +1,156 @@ +{ + "description": "bulkWrite-updateOne-pipeline", + "schemaVersion": "1.0", + "runOnRequirements": [ + { + "minServerVersion": "4.1.11" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "y": 1, + "t": { + "u": { + "v": 1 + } + } + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ], + "tests": [ + { + "description": "UpdateOne in bulk write using pipelines", + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "updateOne": { + "filter": { + "_id": 1 + }, + "update": [ + { + "$replaceRoot": { + "newRoot": "$t" + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + } + } + ] + }, + "expectResult": { + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "test", + "updates": [ + { + "q": { + "_id": 1 + }, + "u": [ + { + "$replaceRoot": { + "newRoot": "$t" + } + }, + { + "$addFields": { + "foo": 1 + } + } + ], + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + }, + "commandName": "update", + "databaseName": "crud-tests" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "u": { + "v": 1 + }, + "foo": 1 + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/bypassDocumentValidation.json b/test/crud/unified/bypassDocumentValidation.json new file mode 100644 index 0000000000..aff2d37f81 --- /dev/null +++ b/test/crud/unified/bypassDocumentValidation.json @@ -0,0 +1,493 @@ +{ + "description": "bypassDocumentValidation", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "3.2", + "serverless": "forbid" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll" + } + } + ], + "initialData": [ + { + "collectionName": "coll", + "databaseName": "crud", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + } + ] + } + ], + "tests": [ + { + "description": "Aggregate with $out passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "aggregate", + "arguments": { + "pipeline": [ + { + "$sort": { + "x": 1 + } + }, + { + "$match": { + "_id": { + "$gt": 1 + } + } + }, + { + "$out": "other_test_collection" + } + ], + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "aggregate": "coll", + "pipeline": 
[ + { + "$sort": { + "x": 1 + } + }, + { + "$match": { + "_id": { + "$gt": 1 + } + } + }, + { + "$out": "other_test_collection" + } + ], + "bypassDocumentValidation": false + }, + "commandName": "aggregate", + "databaseName": "crud" + } + } + ] + } + ] + }, + { + "description": "BulkWrite passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "bulkWrite", + "arguments": { + "requests": [ + { + "insertOne": { + "document": { + "_id": 4, + "x": 44 + } + } + } + ], + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "coll", + "documents": [ + { + "_id": 4, + "x": 44 + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "FindOneAndReplace passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "findOneAndReplace", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "replacement": { + "x": 32 + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "findAndModify": "coll", + "query": { + "_id": { + "$gt": 1 + } + }, + "update": { + "x": 32 + }, + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "FindOneAndUpdate passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "findOneAndUpdate", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "update": { + "$inc": { + "x": 1 + } + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "findAndModify": "coll", + "query": { + "_id": { + "$gt": 1 + } + }, + "update": { + "$inc": { + "x": 1 + } + }, + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "InsertMany passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "insertMany", + "arguments": { + "documents": [ + { + "_id": 4, + "x": 44 + } + ], + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "coll", + "documents": [ + { + "_id": 4, + "x": 44 + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "InsertOne passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "document": { + "_id": 4, + "x": 44 + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "coll", + "documents": [ + { + "_id": 4, + "x": 44 + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "ReplaceOne passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "replaceOne", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "replacement": { + "x": 32 + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "x": 32 + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + 
"$$unsetOrMatches": false + } + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "UpdateMany passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "updateMany", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "update": { + "$inc": { + "x": 1 + } + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "$inc": { + "x": 1 + } + }, + "multi": true, + "upsert": { + "$$unsetOrMatches": false + } + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + }, + { + "description": "UpdateOne passes bypassDocumentValidation: false", + "operations": [ + { + "object": "collection0", + "name": "updateOne", + "arguments": { + "filter": { + "_id": { + "$gt": 1 + } + }, + "update": { + "$inc": { + "x": 1 + } + }, + "bypassDocumentValidation": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "coll", + "updates": [ + { + "q": { + "_id": { + "$gt": 1 + } + }, + "u": { + "$inc": { + "x": 1 + } + }, + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ], + "bypassDocumentValidation": false + } + } + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/findOneAndUpdate-pipeline.json b/test/crud/unified/findOneAndUpdate-pipeline.json new file mode 100644 index 0000000000..81dba9ae93 --- /dev/null +++ b/test/crud/unified/findOneAndUpdate-pipeline.json @@ -0,0 +1,130 @@ +{ + "description": "findOneAndUpdate-pipeline", + "schemaVersion": "1.0", + "runOnRequirements": [ + { + "minServerVersion": "4.1.11" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "y": 1, + "t": { + "u": { + "v": 1 + } + } + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ], + "tests": [ + { + "description": "FindOneAndUpdate using pipelines", + "operations": [ + { + "object": "collection0", + "name": "findOneAndUpdate", + "arguments": { + "filter": { + "_id": 1 + }, + "update": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "findAndModify": "test", + "update": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + }, + "commandName": "findAndModify", + "databaseName": "crud-tests" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "foo": 1 + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/updateMany-pipeline.json b/test/crud/unified/updateMany-pipeline.json new file mode 100644 index 0000000000..e0f6d9d4a4 --- /dev/null +++ b/test/crud/unified/updateMany-pipeline.json @@ -0,0 +1,142 @@ +{ + "description": "updateMany-pipeline", + "schemaVersion": "1.0", + 
"runOnRequirements": [ + { + "minServerVersion": "4.1.11" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "y": 1, + "t": { + "u": { + "v": 1 + } + } + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ], + "tests": [ + { + "description": "UpdateMany using pipelines", + "operations": [ + { + "object": "collection0", + "name": "updateMany", + "arguments": { + "filter": {}, + "update": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + }, + "expectResult": { + "matchedCount": 2, + "modifiedCount": 2, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "test", + "updates": [ + { + "q": {}, + "u": [ + { + "$project": { + "x": 1 + } + }, + { + "$addFields": { + "foo": 1 + } + } + ], + "multi": true, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + }, + "commandName": "update", + "databaseName": "crud-tests" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "foo": 1 + }, + { + "_id": 2, + "x": 2, + "foo": 1 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/updateOne-pipeline.json b/test/crud/unified/updateOne-pipeline.json new file mode 100644 index 0000000000..1348c6b53b --- /dev/null +++ b/test/crud/unified/updateOne-pipeline.json @@ -0,0 +1,150 @@ +{ + "description": "updateOne-pipeline", + "schemaVersion": "1.0", + "runOnRequirements": [ + { + "minServerVersion": "4.1.11" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "crud-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "x": 1, + "y": 1, + "t": { + "u": { + "v": 1 + } + } + }, + { + "_id": 2, + "x": 2, + "y": 1 + } + ] + } + ], + "tests": [ + { + "description": "UpdateOne using pipelines", + "operations": [ + { + "object": "collection0", + "name": "updateOne", + "arguments": { + "filter": { + "_id": 1 + }, + "update": [ + { + "$replaceRoot": { + "newRoot": "$t" + } + }, + { + "$addFields": { + "foo": 1 + } + } + ] + }, + "expectResult": { + "matchedCount": 1, + "modifiedCount": 1, + "upsertedCount": 0 + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "update": "test", + "updates": [ + { + "q": { + "_id": 1 + }, + "u": [ + { + "$replaceRoot": { + "newRoot": "$t" + } + }, + { + "$addFields": { + "foo": 1 + } + } + ], + "multi": { + "$$unsetOrMatches": false + }, + "upsert": { + "$$unsetOrMatches": false + } + } + ] + }, + "commandName": "update", + "databaseName": "crud-tests" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "crud-tests", + "documents": [ + { + "_id": 1, + "u": { + "v": 1 + }, + "foo": 1 + }, + { + "_id": 2, + 
"x": 2, + "y": 1 + } + ] + } + ] + } + ] +} diff --git a/test/crud/unified/updateWithPipelines.json b/test/crud/unified/updateWithPipelines.json deleted file mode 100644 index 164f2f6a19..0000000000 --- a/test/crud/unified/updateWithPipelines.json +++ /dev/null @@ -1,494 +0,0 @@ -{ - "description": "updateWithPipelines", - "schemaVersion": "1.0", - "runOnRequirements": [ - { - "minServerVersion": "4.1.11" - } - ], - "createEntities": [ - { - "client": { - "id": "client0", - "observeEvents": [ - "commandStartedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "crud-tests" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "test" - } - } - ], - "initialData": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "x": 1, - "y": 1, - "t": { - "u": { - "v": 1 - } - } - }, - { - "_id": 2, - "x": 2, - "y": 1 - } - ] - } - ], - "tests": [ - { - "description": "UpdateOne using pipelines", - "operations": [ - { - "object": "collection0", - "name": "updateOne", - "arguments": { - "filter": { - "_id": 1 - }, - "update": [ - { - "$replaceRoot": { - "newRoot": "$t" - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - }, - "expectResult": { - "matchedCount": 1, - "modifiedCount": 1, - "upsertedCount": 0 - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "update": "test", - "updates": [ - { - "q": { - "_id": 1 - }, - "u": [ - { - "$replaceRoot": { - "newRoot": "$t" - } - }, - { - "$addFields": { - "foo": 1 - } - } - ], - "multi": { - "$$unsetOrMatches": false - }, - "upsert": { - "$$unsetOrMatches": false - } - } - ] - }, - "commandName": "update", - "databaseName": "crud-tests" - } - } - ] - } - ], - "outcome": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "u": { - "v": 1 - }, - "foo": 1 - }, - { - "_id": 2, - "x": 2, - "y": 1 - } - ] - } - ] - }, - { - "description": "UpdateMany using pipelines", - "operations": [ - { - "object": "collection0", - "name": "updateMany", - "arguments": { - "filter": {}, - "update": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - }, - "expectResult": { - "matchedCount": 2, - "modifiedCount": 2, - "upsertedCount": 0 - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "update": "test", - "updates": [ - { - "q": {}, - "u": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ], - "multi": true, - "upsert": { - "$$unsetOrMatches": false - } - } - ] - }, - "commandName": "update", - "databaseName": "crud-tests" - } - } - ] - } - ], - "outcome": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "x": 1, - "foo": 1 - }, - { - "_id": 2, - "x": 2, - "foo": 1 - } - ] - } - ] - }, - { - "description": "FindOneAndUpdate using pipelines", - "operations": [ - { - "object": "collection0", - "name": "findOneAndUpdate", - "arguments": { - "filter": { - "_id": 1 - }, - "update": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "findAndModify": "test", - "update": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - }, - "commandName": 
"findAndModify", - "databaseName": "crud-tests" - } - } - ] - } - ], - "outcome": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "x": 1, - "foo": 1 - }, - { - "_id": 2, - "x": 2, - "y": 1 - } - ] - } - ] - }, - { - "description": "UpdateOne in bulk write using pipelines", - "operations": [ - { - "object": "collection0", - "name": "bulkWrite", - "arguments": { - "requests": [ - { - "updateOne": { - "filter": { - "_id": 1 - }, - "update": [ - { - "$replaceRoot": { - "newRoot": "$t" - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - } - } - ] - }, - "expectResult": { - "matchedCount": 1, - "modifiedCount": 1, - "upsertedCount": 0 - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "update": "test", - "updates": [ - { - "q": { - "_id": 1 - }, - "u": [ - { - "$replaceRoot": { - "newRoot": "$t" - } - }, - { - "$addFields": { - "foo": 1 - } - } - ], - "multi": { - "$$unsetOrMatches": false - }, - "upsert": { - "$$unsetOrMatches": false - } - } - ] - }, - "commandName": "update", - "databaseName": "crud-tests" - } - } - ] - } - ], - "outcome": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "u": { - "v": 1 - }, - "foo": 1 - }, - { - "_id": 2, - "x": 2, - "y": 1 - } - ] - } - ] - }, - { - "description": "UpdateMany in bulk write using pipelines", - "operations": [ - { - "object": "collection0", - "name": "bulkWrite", - "arguments": { - "requests": [ - { - "updateMany": { - "filter": {}, - "update": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ] - } - } - ] - }, - "expectResult": { - "matchedCount": 2, - "modifiedCount": 2, - "upsertedCount": 0 - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "update": "test", - "updates": [ - { - "q": {}, - "u": [ - { - "$project": { - "x": 1 - } - }, - { - "$addFields": { - "foo": 1 - } - } - ], - "multi": true, - "upsert": { - "$$unsetOrMatches": false - } - } - ] - }, - "commandName": "update", - "databaseName": "crud-tests" - } - } - ] - } - ], - "outcome": [ - { - "collectionName": "test", - "databaseName": "crud-tests", - "documents": [ - { - "_id": 1, - "x": 1, - "foo": 1 - }, - { - "_id": 2, - "x": 2, - "foo": 1 - } - ] - } - ] - } - ] -} diff --git a/test/csot/runCursorCommand.json b/test/csot/runCursorCommand.json new file mode 100644 index 0000000000..36f774fb5a --- /dev/null +++ b/test/csot/runCursorCommand.json @@ -0,0 +1,583 @@ +{ + "description": "runCursorCommand", + "schemaVersion": "1.9", + "runOnRequirements": [ + { + "minServerVersion": "4.4" + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "commandClient", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent" + ] + } + }, + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent" + ], + "ignoreCommandMonitoringEvents": [ + "killCursors" + ] + } + }, + { + "database": { + "id": "commandDb", + "client": "commandClient", + "databaseName": "commandDb" + } + }, + { + "database": { + "id": "db", + "client": "client", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection", + "database": "db", + "collectionName": "collection" + } + } + ], + "initialData": [ + { + "collectionName": "collection", + 
"databaseName": "db", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + }, + { + "_id": 3, + "x": 33 + }, + { + "_id": 4, + "x": 44 + }, + { + "_id": 5, + "x": 55 + } + ] + } + ], + "tests": [ + { + "description": "errors if timeoutMode is set without timeoutMS", + "operations": [ + { + "name": "runCursorCommand", + "object": "db", + "arguments": { + "commandName": "find", + "command": { + "find": "collection" + }, + "timeoutMode": "cursorLifetime" + }, + "expectError": { + "isClientError": true + } + } + ] + }, + { + "description": "error if timeoutMode is cursorLifetime and cursorType is tailableAwait", + "operations": [ + { + "name": "runCursorCommand", + "object": "db", + "arguments": { + "commandName": "find", + "command": { + "find": "collection" + }, + "timeoutMode": "cursorLifetime", + "cursorType": "tailableAwait" + }, + "expectError": { + "isClientError": true + } + } + ] + }, + { + "description": "Non-tailable cursor lifetime remaining timeoutMS applied to getMore if timeoutMode is unset", + "runOnRequirements": [ + { + "serverless": "forbid" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "find", + "getMore" + ], + "blockConnection": true, + "blockTimeMS": 60 + } + } + } + }, + { + "name": "runCursorCommand", + "object": "db", + "arguments": { + "commandName": "find", + "timeoutMS": 100, + "command": { + "find": "collection", + "batchSize": 2 + } + }, + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "command": { + "find": "collection", + "maxTimeMS": { + "$$type": [ + "int", + "long" + ] + } + } + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "command": { + "getMore": { + "$$type": [ + "int", + "long" + ] + }, + "collection": "collection", + "maxTimeMS": { + "$$exists": false + } + } + } + } + ] + } + ] + }, + { + "description": "Non-tailable cursor iteration timeoutMS is refreshed for getMore if timeoutMode is iteration - failure", + "runOnRequirements": [ + { + "serverless": "forbid" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "getMore" + ], + "blockConnection": true, + "blockTimeMS": 60 + } + } + } + }, + { + "name": "runCursorCommand", + "object": "db", + "arguments": { + "commandName": "find", + "command": { + "find": "collection", + "batchSize": 2 + }, + "timeoutMode": "iteration", + "timeoutMS": 100, + "batchSize": 2 + }, + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "databaseName": "db", + "command": { + "find": "collection", + "maxTimeMS": { + "$$exists": false + } + } + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "db", + "command": { + "getMore": { + "$$type": [ + "int", + "long" + ] + }, + "collection": "collection", + "maxTimeMS": { + "$$exists": false + } + } + } + } + ] + } + ] + }, + { + "description": "Tailable cursor iteration timeoutMS is refreshed for getMore - failure", + "runOnRequirements": [ + { + "serverless": "forbid" + 
} + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "getMore" + ], + "blockConnection": true, + "blockTimeMS": 60 + } + } + } + }, + { + "name": "dropCollection", + "object": "db", + "arguments": { + "collection": "cappedCollection" + } + }, + { + "name": "createCollection", + "object": "db", + "arguments": { + "collection": "cappedCollection", + "capped": true, + "size": 4096, + "max": 3 + }, + "saveResultAsEntity": "cappedCollection" + }, + { + "name": "insertMany", + "object": "cappedCollection", + "arguments": { + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + } + ] + } + }, + { + "name": "createCommandCursor", + "object": "db", + "arguments": { + "commandName": "find", + "command": { + "find": "cappedCollection", + "batchSize": 1, + "tailable": true + }, + "timeoutMode": "iteration", + "timeoutMS": 100, + "batchSize": 1, + "cursorType": "tailable" + }, + "saveResultAsEntity": "tailableCursor" + }, + { + "name": "iterateUntilDocumentOrError", + "object": "tailableCursor" + }, + { + "name": "iterateUntilDocumentOrError", + "object": "tailableCursor", + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "drop" + } + }, + { + "commandStartedEvent": { + "commandName": "create" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "find", + "databaseName": "db", + "command": { + "find": "cappedCollection", + "tailable": true, + "awaitData": { + "$$exists": false + }, + "maxTimeMS": { + "$$exists": false + } + } + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "db", + "command": { + "getMore": { + "$$type": [ + "int", + "long" + ] + }, + "collection": "cappedCollection", + "maxTimeMS": { + "$$exists": false + } + } + } + } + ] + } + ] + }, + { + "description": "Tailable cursor awaitData iteration timeoutMS is refreshed for getMore - failure", + "runOnRequirements": [ + { + "serverless": "forbid" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "getMore" + ], + "blockConnection": true, + "blockTimeMS": 60 + } + } + } + }, + { + "name": "dropCollection", + "object": "db", + "arguments": { + "collection": "cappedCollection" + } + }, + { + "name": "createCollection", + "object": "db", + "arguments": { + "collection": "cappedCollection", + "capped": true, + "size": 4096, + "max": 3 + }, + "saveResultAsEntity": "cappedCollection" + }, + { + "name": "insertMany", + "object": "cappedCollection", + "arguments": { + "documents": [ + { + "foo": "bar" + }, + { + "fizz": "buzz" + } + ] + } + }, + { + "name": "createCommandCursor", + "object": "db", + "arguments": { + "command": { + "find": "cappedCollection", + "tailable": true, + "awaitData": true + }, + "cursorType": "tailableAwait", + "batchSize": 1 + }, + "saveResultAsEntity": "tailableCursor" + }, + { + "name": "iterateUntilDocumentOrError", + "object": "tailableCursor" + }, + { + "name": "iterateUntilDocumentOrError", + "object": "tailableCursor", + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ 
+ { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "drop" + } + }, + { + "commandStartedEvent": { + "commandName": "create" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "find", + "databaseName": "db", + "command": { + "find": "cappedCollection", + "tailable": true, + "awaitData": true, + "maxTimeMS": { + "$$exists": true + } + } + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "db", + "command": { + "getMore": { + "$$type": [ + "int", + "long" + ] + }, + "collection": "cappedCollection" + } + } + } + ] + } + ] + } + ] +} diff --git a/test/csot/tailable-awaitData.json b/test/csot/tailable-awaitData.json index 535fb69243..81683d3993 100644 --- a/test/csot/tailable-awaitData.json +++ b/test/csot/tailable-awaitData.json @@ -3,7 +3,8 @@ "schemaVersion": "1.9", "runOnRequirements": [ { - "minServerVersion": "4.4" + "minServerVersion": "4.4", + "serverless": "forbid" } ], "createEntities": [ @@ -417,6 +418,141 @@ ] } ] + }, + { + "description": "apply remaining timeoutMS if less than maxAwaitTimeMS", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "getMore" + ], + "blockConnection": true, + "blockTimeMS": 30 + } + } + } + }, + { + "name": "createFindCursor", + "object": "collection", + "arguments": { + "filter": { + "_id": 1 + }, + "cursorType": "tailableAwait", + "batchSize": 1, + "maxAwaitTimeMS": 100, + "timeoutMS": 200 + }, + "saveResultAsEntity": "tailableCursor" + }, + { + "name": "iterateOnce", + "object": "tailableCursor" + }, + { + "name": "iterateUntilDocumentOrError", + "object": "tailableCursor", + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "ignoreExtraEvents": true, + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "databaseName": "test" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "test", + "command": { + "maxTimeMS": { + "$$lte": 100 + } + } + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "test", + "command": { + "maxTimeMS": { + "$$lte": 70 + } + } + } + } + ] + } + ] + }, + { + "description": "apply maxAwaitTimeMS if less than remaining timeout", + "operations": [ + { + "name": "createFindCursor", + "object": "collection", + "arguments": { + "filter": {}, + "cursorType": "tailableAwait", + "batchSize": 1, + "maxAwaitTimeMS": 100, + "timeoutMS": 200 + }, + "saveResultAsEntity": "tailableCursor" + }, + { + "name": "iterateOnce", + "object": "tailableCursor" + }, + { + "name": "iterateOnce", + "object": "tailableCursor" + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "databaseName": "test" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore", + "databaseName": "test", + "command": { + "maxTimeMS": { + "$$lte": 100 + } + } + } + } + ] + } + ] } ] } diff --git a/test/csot/waitQueueTimeout.json b/test/csot/waitQueueTimeout.json new file mode 100644 index 0000000000..138d5cc161 --- /dev/null +++ b/test/csot/waitQueueTimeout.json @@ -0,0 +1,176 @@ +{ + "description": "WaitQueueTimeoutError does not clear the pool", + "schemaVersion": "1.9", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + 
"topologies": [ + "single", + "replicaset", + "sharded" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "client", + "uriOptions": { + "maxPoolSize": 1, + "appname": "waitQueueTimeoutErrorTest" + }, + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "poolClearedEvent" + ] + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "test" + } + } + ], + "tests": [ + { + "description": "WaitQueueTimeoutError does not clear the pool", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "ping" + ], + "blockConnection": true, + "blockTimeMS": 500, + "appName": "waitQueueTimeoutErrorTest" + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "thread": { + "id": "thread0" + } + } + ] + } + }, + { + "name": "runOnThread", + "object": "testRunner", + "arguments": { + "thread": "thread0", + "operation": { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "ping": 1 + }, + "commandName": "ping" + } + } + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "commandStartedEvent": { + "commandName": "ping" + } + }, + "count": 1 + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "timeoutMS": 100, + "command": { + "hello": 1 + }, + "commandName": "hello" + }, + "expectError": { + "isTimeoutError": true + } + }, + { + "name": "waitForThread", + "object": "testRunner", + "arguments": { + "thread": "thread0" + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "hello": 1 + }, + "commandName": "hello" + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "command", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping", + "databaseName": "test", + "command": { + "ping": 1 + } + } + }, + { + "commandStartedEvent": { + "commandName": "hello", + "databaseName": "test", + "command": { + "hello": 1 + } + } + } + ] + }, + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/discovery_and_monitoring/rs/new_primary.json b/test/discovery_and_monitoring/rs/new_primary.json index 1a84c69c91..69b07516b9 100644 --- a/test/discovery_and_monitoring/rs/new_primary.json +++ b/test/discovery_and_monitoring/rs/new_primary.json @@ -58,7 +58,8 @@ "servers": { "a:27017": { "type": "Unknown", - "setName": null + "setName": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/new_primary_new_electionid.json b/test/discovery_and_monitoring/rs/new_primary_new_electionid.json index 509720d445..90ef0ce8dc 100644 --- a/test/discovery_and_monitoring/rs/new_primary_new_electionid.json +++ b/test/discovery_and_monitoring/rs/new_primary_new_electionid.json @@ -76,7 +76,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", @@ -123,7 +124,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked 
stale due to electionId/setVersion mismatch" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/new_primary_new_setversion.json b/test/discovery_and_monitoring/rs/new_primary_new_setversion.json index 96533c61ee..9c1e2d4bdd 100644 --- a/test/discovery_and_monitoring/rs/new_primary_new_setversion.json +++ b/test/discovery_and_monitoring/rs/new_primary_new_setversion.json @@ -76,7 +76,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", @@ -123,7 +124,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to electionId/setVersion mismatch" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/primary_disconnect_electionid.json b/test/discovery_and_monitoring/rs/primary_disconnect_electionid.json index 5a91188ea8..b030bd2c53 100644 --- a/test/discovery_and_monitoring/rs/primary_disconnect_electionid.json +++ b/test/discovery_and_monitoring/rs/primary_disconnect_electionid.json @@ -48,7 +48,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", @@ -124,6 +125,7 @@ "a:27017": { "type": "Unknown", "setName": null, + "error": "primary marked stale due to electionId/setVersion mismatch", "electionId": null }, "b:27017": { diff --git a/test/discovery_and_monitoring/rs/primary_disconnect_setversion.json b/test/discovery_and_monitoring/rs/primary_disconnect_setversion.json index f7417ad77b..653a5f29e8 100644 --- a/test/discovery_and_monitoring/rs/primary_disconnect_setversion.json +++ b/test/discovery_and_monitoring/rs/primary_disconnect_setversion.json @@ -48,7 +48,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", @@ -124,6 +125,7 @@ "a:27017": { "type": "Unknown", "setName": null, + "error": "primary marked stale due to electionId/setVersion mismatch", "electionId": null }, "b:27017": { diff --git a/test/discovery_and_monitoring/rs/secondary_ipv6_literal.json b/test/discovery_and_monitoring/rs/secondary_ipv6_literal.json new file mode 100644 index 0000000000..c23d8dc4c9 --- /dev/null +++ b/test/discovery_and_monitoring/rs/secondary_ipv6_literal.json @@ -0,0 +1,38 @@ +{ + "description": "Secondary with IPv6 literal", + "uri": "mongodb://[::1]/?replicaSet=rs", + "phases": [ + { + "responses": [ + [ + "[::1]:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "setName": "rs", + "me": "[::1]:27017", + "hosts": [ + "[::1]:27017" + ], + "minWireVersion": 0, + "maxWireVersion": 26 + } + ] + ], + "outcome": { + "servers": { + "[::1]:27017": { + "type": "RSSecondary", + "setName": "rs" + } + }, + "topologyType": "ReplicaSetNoPrimary", + "setName": "rs", + "logicalSessionTimeoutMinutes": null, + "compatible": true + } + } + ] +} diff --git a/test/discovery_and_monitoring/rs/setversion_greaterthan_max_without_electionid.json b/test/discovery_and_monitoring/rs/setversion_greaterthan_max_without_electionid.json index 97870d71d5..06c89609f5 100644 --- a/test/discovery_and_monitoring/rs/setversion_greaterthan_max_without_electionid.json +++ 
b/test/discovery_and_monitoring/rs/setversion_greaterthan_max_without_electionid.json @@ -65,7 +65,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/setversion_without_electionid-pre-6.0.json b/test/discovery_and_monitoring/rs/setversion_without_electionid-pre-6.0.json index e62c6963ed..87029e578b 100644 --- a/test/discovery_and_monitoring/rs/setversion_without_electionid-pre-6.0.json +++ b/test/discovery_and_monitoring/rs/setversion_without_electionid-pre-6.0.json @@ -65,7 +65,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/use_setversion_without_electionid-pre-6.0.json b/test/discovery_and_monitoring/rs/use_setversion_without_electionid-pre-6.0.json index 2f9b567b85..a63efeac12 100644 --- a/test/discovery_and_monitoring/rs/use_setversion_without_electionid-pre-6.0.json +++ b/test/discovery_and_monitoring/rs/use_setversion_without_electionid-pre-6.0.json @@ -73,7 +73,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to discovery of newer primary" }, "b:27017": { "type": "RSPrimary", @@ -117,7 +118,8 @@ "a:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to electionId/setVersion mismatch" }, "b:27017": { "type": "RSPrimary", diff --git a/test/discovery_and_monitoring/rs/use_setversion_without_electionid.json b/test/discovery_and_monitoring/rs/use_setversion_without_electionid.json index 551f3e12c2..eaf586d728 100644 --- a/test/discovery_and_monitoring/rs/use_setversion_without_electionid.json +++ b/test/discovery_and_monitoring/rs/use_setversion_without_electionid.json @@ -81,7 +81,8 @@ "b:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to electionId/setVersion mismatch" } }, "topologyType": "ReplicaSetWithPrimary", @@ -128,7 +129,8 @@ "b:27017": { "type": "Unknown", "setName": null, - "electionId": null + "electionId": null, + "error": "primary marked stale due to electionId/setVersion mismatch" } }, "topologyType": "ReplicaSetWithPrimary", diff --git a/test/discovery_and_monitoring/unified/serverMonitoringMode.json b/test/discovery_and_monitoring/unified/serverMonitoringMode.json index 4b492f7d85..e44fad1bcd 100644 --- a/test/discovery_and_monitoring/unified/serverMonitoringMode.json +++ b/test/discovery_and_monitoring/unified/serverMonitoringMode.json @@ -5,8 +5,7 @@ { "topologies": [ "single", - "sharded", - "sharded-replicaset" + "sharded" ], "serverless": "forbid" } diff --git a/test/gridfs/delete.json b/test/gridfs/delete.json index 7a4ec27f88..277b9ed7e1 100644 --- a/test/gridfs/delete.json +++ b/test/gridfs/delete.json @@ -49,10 +49,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -64,10 +61,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": 
[], "metadata": {} }, { @@ -79,10 +73,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -94,10 +85,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "dd254cdc958e53abaa67da9f797125f5", "filename": "length-8", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] @@ -197,10 +185,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -212,10 +197,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -227,10 +209,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "dd254cdc958e53abaa67da9f797125f5", "filename": "length-8", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] @@ -330,10 +309,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -345,10 +321,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -360,10 +333,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "dd254cdc958e53abaa67da9f797125f5", "filename": "length-8", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] @@ -448,10 +418,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -463,10 +430,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -478,10 +442,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] @@ -554,10 +515,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -569,10 +527,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -584,10 +539,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -599,10 +551,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "dd254cdc958e53abaa67da9f797125f5", "filename": "length-8", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] @@ -719,10 +668,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - 
"md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -734,10 +680,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -749,10 +692,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] diff --git a/test/gridfs/deleteByName.json b/test/gridfs/deleteByName.json new file mode 100644 index 0000000000..884d0300ce --- /dev/null +++ b/test/gridfs/deleteByName.json @@ -0,0 +1,230 @@ +{ + "description": "gridfs-deleteByName", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "gridfs-tests" + } + }, + { + "bucket": { + "id": "bucket0", + "database": "database0" + } + }, + { + "collection": { + "id": "bucket0_files_collection", + "database": "database0", + "collectionName": "fs.files" + } + }, + { + "collection": { + "id": "bucket0_chunks_collection", + "database": "database0", + "collectionName": "fs.chunks" + } + } + ], + "initialData": [ + { + "collectionName": "fs.files", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "length": 2, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000004" + }, + "length": 8, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "otherfilename", + "metadata": {} + } + ] + }, + { + "collectionName": "fs.chunks", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "files_id": { + "$oid": "000000000000000000000002" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "files_id": { + "$oid": "000000000000000000000003" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "files_id": { + "$oid": "000000000000000000000003" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000004" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + } + ] + } + ], + "tests": [ + { + "description": "delete when multiple revisions of the file exist", + "operations": [ + { + "name": "deleteByName", + "object": "bucket0", + "arguments": { + "filename": "filename" + } + } + ], + "outcome": [ + { + "collectionName": "fs.files", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": 
"000000000000000000000004" + }, + "length": 8, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "otherfilename", + "metadata": {} + } + ] + }, + { + "collectionName": "fs.chunks", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000004" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + } + ] + } + ] + }, + { + "description": "delete when file name does not exist", + "operations": [ + { + "name": "deleteByName", + "object": "bucket0", + "arguments": { + "filename": "missing-file" + }, + "expectError": { + "isClientError": true + } + } + ] + } + ] +} diff --git a/test/gridfs/download.json b/test/gridfs/download.json index 48d3246218..f0cb851708 100644 --- a/test/gridfs/download.json +++ b/test/gridfs/download.json @@ -49,10 +49,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -64,10 +61,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "d41d8cd98f00b204e9800998ecf8427e", "filename": "length-0-with-empty-chunk", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -79,10 +73,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", "filename": "length-2", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -94,10 +85,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "dd254cdc958e53abaa67da9f797125f5", "filename": "length-8", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -109,10 +97,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "57d83cd477bfb1ccd975ab33d827a92b", "filename": "length-10", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -124,9 +109,6 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "c700ed4fdb1d27055aa3faa2c2432283", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] diff --git a/test/gridfs/downloadByName.json b/test/gridfs/downloadByName.json index cd44663957..7b20933c16 100644 --- a/test/gridfs/downloadByName.json +++ b/test/gridfs/downloadByName.json @@ -49,10 +49,7 @@ "uploadDate": { "$date": "1970-01-01T00:00:00.000Z" }, - "md5": "47ed733b8d10be225eceba344d533586", "filename": "abc", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -64,10 +61,7 @@ "uploadDate": { "$date": "1970-01-02T00:00:00.000Z" }, - "md5": "b15835f133ff2e27c7cb28117bfae8f4", "filename": "abc", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -79,10 +73,7 @@ "uploadDate": { "$date": "1970-01-03T00:00:00.000Z" }, - "md5": "eccbc87e4b5ce2fe28308fd9f2a7baf3", "filename": "abc", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -94,10 +85,7 @@ "uploadDate": { "$date": "1970-01-04T00:00:00.000Z" }, - "md5": "f623e75af30e62bbd73d6df5b50bb7b5", "filename": "abc", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} }, { @@ -109,10 +97,7 @@ "uploadDate": { "$date": "1970-01-05T00:00:00.000Z" }, - "md5": "4c614360da93c0a041b22e537de151eb", "filename": "abc", - "contentType": "application/octet-stream", - "aliases": [], "metadata": {} } ] diff --git 
a/test/gridfs/renameByName.json b/test/gridfs/renameByName.json new file mode 100644 index 0000000000..26f04fb9e0 --- /dev/null +++ b/test/gridfs/renameByName.json @@ -0,0 +1,313 @@ +{ + "description": "gridfs-renameByName", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "gridfs-tests" + } + }, + { + "bucket": { + "id": "bucket0", + "database": "database0" + } + }, + { + "collection": { + "id": "bucket0_files_collection", + "database": "database0", + "collectionName": "fs.files" + } + }, + { + "collection": { + "id": "bucket0_chunks_collection", + "database": "database0", + "collectionName": "fs.chunks" + } + } + ], + "initialData": [ + { + "collectionName": "fs.files", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "length": 2, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "filename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000004" + }, + "length": 8, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "otherfilename", + "metadata": {} + } + ] + }, + { + "collectionName": "fs.chunks", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "files_id": { + "$oid": "000000000000000000000002" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "files_id": { + "$oid": "000000000000000000000003" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000004" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 1, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + } + ] + } + ], + "tests": [ + { + "description": "rename when multiple revisions of the file exist", + "operations": [ + { + "name": "renameByName", + "object": "bucket0", + "arguments": { + "filename": "filename", + "newFilename": "newfilename" + } + } + ], + "outcome": [ + { + "collectionName": "fs.files", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "newfilename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "length": 0, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "newfilename", + "metadata": {} + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "length": 2, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "newfilename", + "metadata": {} + }, + { + "_id": { + "$oid": 
"000000000000000000000004" + }, + "length": 8, + "chunkSize": 4, + "uploadDate": { + "$date": "1970-01-01T00:00:00.000Z" + }, + "filename": "otherfilename", + "metadata": {} + } + ] + }, + { + "collectionName": "fs.chunks", + "databaseName": "gridfs-tests", + "documents": [ + { + "_id": { + "$oid": "000000000000000000000001" + }, + "files_id": { + "$oid": "000000000000000000000002" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000002" + }, + "files_id": { + "$oid": "000000000000000000000003" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000003" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 0, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + }, + { + "_id": { + "$oid": "000000000000000000000004" + }, + "files_id": { + "$oid": "000000000000000000000004" + }, + "n": 1, + "data": { + "$binary": { + "base64": "", + "subType": "00" + } + } + } + ] + } + ] + }, + { + "description": "rename when file name does not exist", + "operations": [ + { + "name": "renameByName", + "object": "bucket0", + "arguments": { + "filename": "missing-file", + "newFilename": "newfilename" + }, + "expectError": { + "isClientError": true + } + } + ] + } + ] +} diff --git a/test/gridfs/upload.json b/test/gridfs/upload.json index 97e18d2bc2..3c1644653a 100644 --- a/test/gridfs/upload.json +++ b/test/gridfs/upload.json @@ -470,75 +470,6 @@ } ] }, - { - "description": "upload when contentType is provided", - "operations": [ - { - "name": "upload", - "object": "bucket0", - "arguments": { - "filename": "filename", - "source": { - "$$hexBytes": "11" - }, - "chunkSizeBytes": 4, - "contentType": "image/jpeg" - }, - "expectResult": { - "$$type": "objectId" - }, - "saveResultAsEntity": "uploadedObjectId" - }, - { - "name": "find", - "object": "bucket0_files_collection", - "arguments": { - "filter": {} - }, - "expectResult": [ - { - "_id": { - "$$matchesEntity": "uploadedObjectId" - }, - "length": 1, - "chunkSize": 4, - "uploadDate": { - "$$type": "date" - }, - "md5": { - "$$unsetOrMatches": "47ed733b8d10be225eceba344d533586" - }, - "filename": "filename", - "contentType": "image/jpeg" - } - ] - }, - { - "name": "find", - "object": "bucket0_chunks_collection", - "arguments": { - "filter": {} - }, - "expectResult": [ - { - "_id": { - "$$type": "objectId" - }, - "files_id": { - "$$matchesEntity": "uploadedObjectId" - }, - "n": 0, - "data": { - "$binary": { - "base64": "EQ==", - "subType": "00" - } - } - } - ] - } - ] - }, { "description": "upload when metadata is provided", "operations": [ diff --git a/test/helpers.py b/test/helpers.py index 11d5ab0374..12c55ade1b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -37,14 +40,14 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON from pymongo import common, message from pymongo.read_preferences import 
ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl @@ -78,7 +81,7 @@ COMPRESSORS = os.environ.get("COMPRESSORS") MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") -TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOAD_BALANCER")) TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") @@ -369,3 +372,53 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, **kwargs): + if _IS_SYNC: + super().__init__(**kwargs) + self.name = kwargs.get("name", "ConcurrentRunner") + self.stopped = False + self.task = None + self.target = kwargs.get("target", None) + self.args = kwargs.get("args", []) + + if not _IS_SYNC: + + def start(self): + self.task = create_task(self.run(), name=self.name) + + def join(self, timeout: Optional[float] = None): # type: ignore[override] + if self.task is not None: + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + try: + self.target(*self.args) + finally: + self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.exc = None + + def run(self): + try: + super().run() + except BaseException as exc: + self.exc = exc + raise diff --git a/test/lambda/build.sh b/test/lambda/build.sh deleted file mode 100755 index c7cc24eab2..0000000000 --- a/test/lambda/build.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash -set -o errexit # Exit the script with error if any of the commands fail -set -o xtrace - -rm -rf mongodb/pymongo -rm -rf mongodb/gridfs -rm -rf mongodb/bson - -pushd ../.. -rm -f pymongo/*.so -rm -f bson/*.so -image="quay.io/pypa/manylinux2014_x86_64:latest" - -DOCKER=$(command -v docker) || true -if [ -z "$DOCKER" ]; then - PODMAN=$(command -v podman) || true - if [ -z "$PODMAN" ]; then - echo "docker or podman are required!" 
- exit 1 - fi - DOCKER=podman -fi - -$DOCKER run --rm -v "`pwd`:/src" $image /src/test/lambda/build_internal.sh -cp -r pymongo ./test/lambda/mongodb/pymongo -cp -r bson ./test/lambda/mongodb/bson -cp -r gridfs ./test/lambda/mongodb/gridfs -popd diff --git a/test/load_balancer/transactions.json b/test/load_balancer/transactions.json index 0dd04ee854..ca9c145217 100644 --- a/test/load_balancer/transactions.json +++ b/test/load_balancer/transactions.json @@ -1616,6 +1616,50 @@ ] } ] + }, + { + "description": "pinned connection is released when session ended", + "operations": [ + { + "name": "startTransaction", + "object": "session0" + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "x": 1 + }, + "session": "session0" + } + }, + { + "name": "commitTransaction", + "object": "session0" + }, + { + "name": "endSession", + "object": "session0" + } + ], + "expectEvents": [ + { + "client": "client0", + "eventType": "cmap", + "events": [ + { + "connectionReadyEvent": {} + }, + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckedInEvent": {} + } + ] + } + ] } ] } diff --git a/test/mockupdb/test_cluster_time.py b/test/mockupdb/test_cluster_time.py index ea879b7ea3..42ca916971 100644 --- a/test/mockupdb/test_cluster_time.py +++ b/test/mockupdb/test_cluster_time.py @@ -123,50 +123,11 @@ def test_monitor(self): client = self.simple_client(server.uri, heartbeatFrequencyMS=500) - request = server.receives("ismaster") - # No $clusterTime in first ismaster, only in subsequent ones - self.assertNotIn("$clusterTime", request) - request.ok(reply) - - # Next exchange: client returns first clusterTime, we send the second. - request = server.receives("ismaster") - self.assertIn("$clusterTime", request) - self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time) - cluster_time = Timestamp(cluster_time.time, cluster_time.inc + 1) - reply["$clusterTime"] = {"clusterTime": cluster_time} - request.reply(reply) - - # Third exchange: client returns second clusterTime. - request = server.receives("ismaster") - self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time) - - # Return command error with a new clusterTime. - cluster_time = Timestamp(cluster_time.time, cluster_time.inc + 1) - error = { - "ok": 0, - "code": 211, - "errmsg": "Cache Reader No keys found for HMAC ...", - "$clusterTime": {"clusterTime": cluster_time}, - } - request.reply(error) - - # PyMongo 3.11+ closes the monitoring connection on command errors. - - # Fourth exchange: the Monitor closes the connection and runs the - # handshake on a new connection. - request = server.receives("ismaster") - # No $clusterTime in first ismaster, only in subsequent ones - self.assertNotIn("$clusterTime", request) - - # Reply without $clusterTime. - reply.pop("$clusterTime") - request.reply(reply) - - # Fifth exchange: the Monitor attempt uses the clusterTime from - # the previous isMaster error. - request = server.receives("ismaster") - self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time) - request.reply(reply) + for _ in range(3): + request = server.receives("ismaster") + # No $clusterTime in heartbeats or handshakes. 
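+            # Monitors use dedicated connections, and sessions gossip
+            # $clusterTime only over application connections, so none is
+            # expected here.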
+ self.assertNotIn("$clusterTime", request) + request.ok(reply) client.close() def test_collection_bulk_error(self): diff --git a/test/mod_wsgi_test/test_client.py b/test/mod_wsgi_test/test_client.py index 88eeb7a57e..c122863bfa 100644 --- a/test/mod_wsgi_test/test_client.py +++ b/test/mod_wsgi_test/test_client.py @@ -24,7 +24,7 @@ from urllib.request import urlopen -def parse_args(): +def parse_args(args=None): parser = OptionParser( """usage: %prog [options] mode url [...] @@ -70,7 +70,7 @@ def parse_args(): ) try: - options, args = parser.parse_args() + options, args = parser.parse_args(args or sys.argv[1:]) mode, urls = args[0], args[1:] except (ValueError, IndexError): parser.print_usage() @@ -103,11 +103,11 @@ def __init__(self, options, urls, nrequests_per_thread): def run(self): for _i in range(self.nrequests_per_thread): try: - get(urls) + get(self.urls) except Exception as e: print(e) - if not options.continue_: + if not self.options.continue_: thread.interrupt_main() thread.exit() @@ -117,7 +117,7 @@ def run(self): URLGetterThread.counter += 1 counter = URLGetterThread.counter - should_print = options.verbose and not counter % 1000 + should_print = self.options.verbose and not counter % 1000 if should_print: print(counter) diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index a42b3a34ee..b20eaa35d6 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -19,6 +19,7 @@ import os import sys import unittest +from pathlib import Path import pytest @@ -38,15 +39,10 @@ FORMAT = "%(asctime)s %(levelname)s %(module)s %(message)s" logging.basicConfig(format=FORMAT, level=logging.DEBUG) -if sys.platform == "win32": - # The non-stapled OCSP endpoint check is slow on Windows. - TIMEOUT_MS = 5000 -else: - TIMEOUT_MS = 500 - def _connect(options): - uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}" + assert CA_FILE is not None + uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS=10000&tlsCAFile={Path(CA_FILE).as_posix()}&{options}" print(uri) try: client = pymongo.MongoClient(uri) diff --git a/test/performance/async_perf_test.py b/test/performance/async_perf_test.py new file mode 100644 index 0000000000..969437f9c9 --- /dev/null +++ b/test/performance/async_perf_test.py @@ -0,0 +1,488 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Asynchronous Tests for the MongoDB Driver Performance Benchmarking Spec. 
+
+See https://github.com/mongodb/specifications/blob/master/source/benchmarking/benchmarking.md
+
+
+To set up the benchmarks locally::
+
+    python -m pip install simplejson
+    git clone --depth 1 https://github.com/mongodb/specifications.git
+    pushd specifications/source/benchmarking/data
+    tar xf extended_bson.tgz
+    tar xf parallel.tgz
+    tar xf single_and_multi_document.tgz
+    popd
+    export TEST_PATH="specifications/source/benchmarking/data"
+    export OUTPUT_FILE="results.json"
+
+Then to run all benchmarks quickly::
+
+    FASTBENCH=1 python test/performance/async_perf_test.py -v
+
+To run individual benchmarks quickly::
+
+    FASTBENCH=1 python test/performance/async_perf_test.py -v TestRunCommand TestFindManyAndEmptyCursor
+"""
+from __future__ import annotations
+
+import asyncio
+import os
+import sys
+import tempfile
+import time
+import warnings
+from typing import Any, List, Optional, Union
+
+import pytest
+
+try:
+    import simplejson as json
+except ImportError:
+    import json  # type: ignore[no-redef]
+
+sys.path[0:0] = [""]
+
+from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest
+
+from bson import encode
+from gridfs import AsyncGridFSBucket
+from pymongo import (
+    DeleteOne,
+    InsertOne,
+    ReplaceOne,
+)
+
+pytestmark = pytest.mark.perf
+
+# Spec says to use at least 1 minute cumulative execution time and up to 100
+# iterations or 5 minutes, but that makes the benchmarks too slow. Instead, we
+# use at least 30 seconds and at most 120 seconds.
+NUM_ITERATIONS = 100
+MIN_ITERATION_TIME = 30
+MAX_ITERATION_TIME = 120
+NUM_DOCS = 10000
+# When debugging or prototyping it's often useful to run the benchmarks locally; set FASTBENCH=1 to run quickly.
+if bool(os.getenv("FASTBENCH")):
+    NUM_ITERATIONS = 2
+    MIN_ITERATION_TIME = 1
+    MAX_ITERATION_TIME = 30
+    NUM_DOCS = 1000
+
+TEST_PATH = os.environ.get(
+    "TEST_PATH", os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.join("data"))
+)
+
+OUTPUT_FILE = os.environ.get("OUTPUT_FILE")
+
+result_data: List = []
+
+
+def tearDownModule():
+    output = json.dumps(result_data, indent=4)
+    if OUTPUT_FILE:
+        with open(OUTPUT_FILE, "w") as opf:
+            opf.write(output)
+    else:
+        print(output)
+
+
+class Timer:
+    def __enter__(self):
+        self.start = time.monotonic()
+        return self
+
+    def __exit__(self, *args):
+        self.end = time.monotonic()
+        self.interval = self.end - self.start
+
+
+async def concurrent(n_tasks, func):
+    tasks = [func() for _ in range(n_tasks)]
+    await asyncio.gather(*tasks)
+
+
+class PerformanceTest:
+    dataset: str
+    data_size: int
+    fail: Any
+    n_tasks: int = 1
+    did_init: bool = False
+
+    async def asyncSetUp(self):
+        await async_client_context.init()
+        self.setup_time = time.monotonic()
+
+    async def asyncTearDown(self):
+        duration = time.monotonic() - self.setup_time
+        # Remove "Test" so that TestFlatEncoding is reported as "FlatEncoding".
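+        # Throughput below is (per-iteration bytes * task count) divided by
+        # the median iteration time, i.e. the spec's MB/sec metric.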
+ name = self.__class__.__name__[4:] + median = self.percentile(50) + megabytes_per_sec = (self.data_size * self.n_tasks) / median / 1000000 + print( + f"Completed {self.__class__.__name__} {megabytes_per_sec:.3f} MB/s, MEDIAN={self.percentile(50):.3f}s, " + f"total time={duration:.3f}s, iterations={len(self.results)}" + ) + result_data.append( + { + "info": { + "test_name": name, + "args": { + "tasks": self.n_tasks, + }, + }, + "metrics": [ + {"name": "megabytes_per_sec", "type": "MEDIAN", "value": megabytes_per_sec}, + ], + } + ) + + async def before(self): + pass + + async def do_task(self): + raise NotImplementedError + + async def after(self): + pass + + def percentile(self, percentile): + if hasattr(self, "results"): + sorted_results = sorted(self.results) + percentile_index = int(len(sorted_results) * percentile / 100) - 1 + return sorted_results[percentile_index] + else: + self.fail("Test execution failed") + return None + + async def runTest(self): + results = [] + start = time.monotonic() + i = 0 + while True: + i += 1 + await self.before() + with Timer() as timer: + if self.n_tasks == 1: + await self.do_task() + else: + await concurrent(self.n_tasks, self.do_task) + await self.after() + results.append(timer.interval) + duration = time.monotonic() - start + if duration > MIN_ITERATION_TIME and i >= NUM_ITERATIONS: + break + if i >= NUM_ITERATIONS: + break + if duration > MAX_ITERATION_TIME: + with warnings.catch_warnings(): + warnings.simplefilter("default") + warnings.warn( + f"{self.__class__.__name__} timed out after {MAX_ITERATION_TIME}s, completed {i}/{NUM_ITERATIONS} iterations." + ) + + break + + self.results = results + + +# SINGLE-DOC BENCHMARKS +class TestRunCommand(PerformanceTest, AsyncPyMongoTestCase): + data_size = len(encode({"hello": True})) * NUM_DOCS + + async def asyncSetUp(self): + await super().asyncSetUp() + self.client = async_client_context.client + await self.client.drop_database("perftest") + + async def do_task(self): + command = self.client.perftest.command + for _ in range(NUM_DOCS): + await command("hello", True) + + +class TestRunCommand8Tasks(TestRunCommand): + n_tasks = 8 + + +class TestRunCommand80Tasks(TestRunCommand): + n_tasks = 80 + + +class TestRunCommandUnlimitedTasks(TestRunCommand): + async def do_task(self): + command = self.client.perftest.command + await asyncio.gather(*[command("hello", True) for _ in range(NUM_DOCS)]) + + +class TestDocument(PerformanceTest): + async def asyncSetUp(self): + await super().asyncSetUp() + # Location of test data. 
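+        # The blocking open() below is intentional (hence the noqa):
+        # it runs once during setup, outside the timed benchmark loop.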
+ with open( # noqa: ASYNC101 + os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)) + ) as data: + self.document = json.loads(data.read()) + + self.client = async_client_context.client + await self.client.drop_database("perftest") + + async def asyncTearDown(self): + await super().asyncTearDown() + await self.client.drop_database("perftest") + + async def before(self): + self.corpus = await self.client.perftest.create_collection("corpus") + + async def after(self): + await self.client.perftest.drop_collection("corpus") + + +class FindTest(TestDocument): + dataset = "tweet.json" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.data_size = len(encode(self.document)) * NUM_DOCS + documents = [self.document.copy() for _ in range(NUM_DOCS)] + self.corpus = self.client.perftest.corpus + result = await self.corpus.insert_many(documents) + self.inserted_ids = result.inserted_ids + + async def before(self): + pass + + async def after(self): + pass + + +class TestFindOneByID(FindTest, AsyncPyMongoTestCase): + async def do_task(self): + find_one = self.corpus.find_one + for _id in self.inserted_ids: + await find_one({"_id": _id}) + + +class TestFindOneByID8Tasks(TestFindOneByID): + n_tasks = 8 + + +class TestFindOneByID80Tasks(TestFindOneByID): + n_tasks = 80 + + +class TestFindOneByIDUnlimitedTasks(TestFindOneByID): + async def do_task(self): + find_one = self.corpus.find_one + await asyncio.gather(*[find_one({"_id": _id}) for _id in self.inserted_ids]) + + +class SmallDocInsertTest(TestDocument): + dataset = "small_doc.json" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.data_size = len(encode(self.document)) * NUM_DOCS + self.documents = [self.document.copy() for _ in range(NUM_DOCS)] + + +class SmallDocMixedTest(TestDocument): + dataset = "small_doc.json" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.data_size = len(encode(self.document)) * NUM_DOCS * 2 + self.documents = [self.document.copy() for _ in range(NUM_DOCS)] + + +class TestSmallDocInsertOne(SmallDocInsertTest, AsyncPyMongoTestCase): + async def do_task(self): + insert_one = self.corpus.insert_one + for doc in self.documents: + await insert_one(doc) + + +class TestSmallDocInsertOneUnlimitedTasks(SmallDocInsertTest, AsyncPyMongoTestCase): + async def do_task(self): + insert_one = self.corpus.insert_one + await asyncio.gather(*[insert_one(doc) for doc in self.documents]) + + +class LargeDocInsertTest(TestDocument): + dataset = "large_doc.json" + + async def asyncSetUp(self): + await super().asyncSetUp() + n_docs = 10 + self.data_size = len(encode(self.document)) * n_docs + self.documents = [self.document.copy() for _ in range(n_docs)] + + +class TestLargeDocInsertOne(LargeDocInsertTest, AsyncPyMongoTestCase): + async def do_task(self): + insert_one = self.corpus.insert_one + for doc in self.documents: + await insert_one(doc) + + +class TestLargeDocInsertOneUnlimitedTasks(LargeDocInsertTest, AsyncPyMongoTestCase): + async def do_task(self): + insert_one = self.corpus.insert_one + await asyncio.gather(*[insert_one(doc) for doc in self.documents]) + + +# MULTI-DOC BENCHMARKS +class TestFindManyAndEmptyCursor(FindTest, AsyncPyMongoTestCase): + async def do_task(self): + await self.corpus.find().to_list() + + +class TestFindManyAndEmptyCursor8Tasks(TestFindManyAndEmptyCursor): + n_tasks = 8 + + +class TestFindManyAndEmptyCursor80Tasks(TestFindManyAndEmptyCursor): + n_tasks = 80 + + +class TestSmallDocBulkInsert(SmallDocInsertTest, AsyncPyMongoTestCase): 
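+    # Baseline batch path: one insert_many call per iteration, for comparison
+    # with the collection- and client-level bulk_write variants below.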
+ async def do_task(self): + await self.corpus.insert_many(self.documents, ordered=True) + + +class TestSmallDocCollectionBulkInsert(SmallDocInsertTest, AsyncPyMongoTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + async def do_task(self): + await self.corpus.bulk_write(self.models, ordered=True) + + +class TestSmallDocClientBulkInsert(SmallDocInsertTest, AsyncPyMongoTestCase): + @async_client_context.require_version_min(8, 0, 0, -24) + async def asyncSetUp(self): + await super().asyncSetUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + @async_client_context.require_version_min(8, 0, 0, -24) + async def do_task(self): + await self.client.bulk_write(self.models, ordered=True) + + +class TestSmallDocBulkMixedOps(SmallDocMixedTest, AsyncPyMongoTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.models: list[Union[InsertOne, ReplaceOne, DeleteOne]] = [] + for doc in self.documents: + self.models.append(InsertOne(document=doc)) + self.models.append(ReplaceOne(filter={}, replacement=doc.copy(), upsert=True)) + self.models.append(DeleteOne(filter={})) + + async def do_task(self): + await self.corpus.bulk_write(self.models, ordered=True) + + +class TestSmallDocClientBulkMixedOps(SmallDocMixedTest, AsyncPyMongoTestCase): + @async_client_context.require_version_min(8, 0, 0, -24) + async def asyncSetUp(self): + await super().asyncSetUp() + self.models: list[Union[InsertOne, ReplaceOne, DeleteOne]] = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + self.models.append( + ReplaceOne( + namespace="perftest.corpus", filter={}, replacement=doc.copy(), upsert=True + ) + ) + self.models.append(DeleteOne(namespace="perftest.corpus", filter={})) + + @async_client_context.require_version_min(8, 0, 0, -24) + async def do_task(self): + await self.client.bulk_write(self.models, ordered=True) + + +class TestLargeDocBulkInsert(LargeDocInsertTest, AsyncPyMongoTestCase): + async def do_task(self): + await self.corpus.insert_many(self.documents, ordered=True) + + +class TestLargeDocCollectionBulkInsert(LargeDocInsertTest, AsyncPyMongoTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + async def do_task(self): + await self.corpus.bulk_write(self.models, ordered=True) + + +class TestLargeDocClientBulkInsert(LargeDocInsertTest, AsyncPyMongoTestCase): + @async_client_context.require_version_min(8, 0, 0, -24) + async def asyncSetUp(self): + await super().asyncSetUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + @async_client_context.require_version_min(8, 0, 0, -24) + async def do_task(self): + await self.client.bulk_write(self.models, ordered=True) + + +class GridFsTest(PerformanceTest): + async def asyncSetUp(self): + await super().asyncSetUp() + self.client = async_client_context.client + await self.client.drop_database("perftest") + + gridfs_path = os.path.join( + TEST_PATH, os.path.join("single_and_multi_document", "gridfs_large.bin") + ) + with open(gridfs_path, "rb") as data: # noqa: ASYNC101 + self.document = data.read() + self.data_size = len(self.document) + self.bucket = 
AsyncGridFSBucket(self.client.perftest) + + async def asyncTearDown(self): + await super().asyncTearDown() + await self.client.drop_database("perftest") + + +class TestGridFsUpload(GridFsTest, AsyncPyMongoTestCase): + async def before(self): + # Create the bucket. + await self.bucket.upload_from_stream("init", b"x") + + async def do_task(self): + await self.bucket.upload_from_stream("gridfstest", self.document) + + +class TestGridFsDownload(GridFsTest, AsyncPyMongoTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.uploaded_id = await self.bucket.upload_from_stream("gridfstest", self.document) + + async def do_task(self): + await (await self.bucket.open_download_stream(self.uploaded_id)).read() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 6e269e25b0..39487eff6d 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -443,6 +443,17 @@ def do_task(self): self.corpus.insert_many(self.documents, ordered=True) +class TestSmallDocCollectionBulkInsert(SmallDocInsertTest, unittest.TestCase): + def setUp(self): + super().setUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + def do_task(self): + self.corpus.bulk_write(self.models, ordered=True) + + class TestSmallDocClientBulkInsert(SmallDocInsertTest, unittest.TestCase): @client_context.require_version_min(8, 0, 0, -24) def setUp(self): @@ -493,6 +504,17 @@ def do_task(self): self.corpus.insert_many(self.documents, ordered=True) +class TestLargeDocCollectionBulkInsert(LargeDocInsertTest, unittest.TestCase): + def setUp(self): + super().setUp() + self.models = [] + for doc in self.documents: + self.models.append(InsertOne(namespace="perftest.corpus", document=doc)) + + def do_task(self): + self.corpus.bulk_write(self.models, ordered=True) + + class TestLargeDocClientBulkInsert(LargeDocInsertTest, unittest.TestCase): @client_context.require_version_min(8, 0, 0, -24) def setUp(self): diff --git a/test/retryable_reads/unified/estimatedDocumentCount.json b/test/retryable_reads/unified/estimatedDocumentCount.json index 75a676b9b6..2ee29f6799 100644 --- a/test/retryable_reads/unified/estimatedDocumentCount.json +++ b/test/retryable_reads/unified/estimatedDocumentCount.json @@ -195,7 +195,7 @@ "object": "collection1", "name": "estimatedDocumentCount", "expectError": { - "isError": true + "isClientError": true } } ], @@ -241,7 +241,7 @@ "object": "collection0", "name": "estimatedDocumentCount", "expectError": { - "isError": true + "isClientError": true } } ], diff --git a/test/retryable_writes/unified/insertOne-serverErrors.json b/test/retryable_writes/unified/insertOne-serverErrors.json index f404adcaf4..8edafb7029 100644 --- a/test/retryable_writes/unified/insertOne-serverErrors.json +++ b/test/retryable_writes/unified/insertOne-serverErrors.json @@ -739,7 +739,7 @@ ] }, { - "description": "InsertOne fails after WriteConcernError WriteConcernFailed", + "description": "InsertOne fails after WriteConcernError WriteConcernTimeout", "operations": [ { "name": "failPoint", @@ -757,7 +757,6 @@ ], "writeConcernError": { "code": 64, - "codeName": "WriteConcernFailed", "errmsg": "waiting for replication timed out", "errInfo": { "wtimeout": true diff --git a/test/run_command/unified/runCommand.json b/test/run_command/unified/runCommand.json index 007e514bd7..fde9de92e6 100644 --- a/test/run_command/unified/runCommand.json +++ 
b/test/run_command/unified/runCommand.json @@ -229,7 +229,6 @@ { "topologies": [ "replicaset", - "sharded-replicaset", "load-balanced", "sharded" ] @@ -493,7 +492,7 @@ { "minServerVersion": "4.2", "topologies": [ - "sharded-replicaset", + "sharded", "load-balanced" ] } diff --git a/test/sessions/driver-sessions-dirty-session-errors.json b/test/sessions/driver-sessions-dirty-session-errors.json index 361ea83d7b..6aa1da1df5 100644 --- a/test/sessions/driver-sessions-dirty-session-errors.json +++ b/test/sessions/driver-sessions-dirty-session-errors.json @@ -11,7 +11,7 @@ { "minServerVersion": "4.1.8", "topologies": [ - "sharded-replicaset" + "sharded" ] } ], diff --git a/test/sessions/snapshot-sessions-unsupported-ops.json b/test/sessions/snapshot-sessions-unsupported-ops.json index 1021b7f264..c41f74d337 100644 --- a/test/sessions/snapshot-sessions-unsupported-ops.json +++ b/test/sessions/snapshot-sessions-unsupported-ops.json @@ -6,7 +6,7 @@ "minServerVersion": "5.0", "topologies": [ "replicaset", - "sharded-replicaset" + "sharded" ] } ], diff --git a/test/sessions/snapshot-sessions.json b/test/sessions/snapshot-sessions.json index 75b577b039..260f8b6f48 100644 --- a/test/sessions/snapshot-sessions.json +++ b/test/sessions/snapshot-sessions.json @@ -6,7 +6,7 @@ "minServerVersion": "5.0", "topologies": [ "replicaset", - "sharded-replicaset" + "sharded" ] } ], diff --git a/test/test_auth.py b/test/test_auth.py index 345d16121b..27f6743fae 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -30,7 +30,7 @@ client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 3c3a1a67ae..9ba15e8d78 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -22,6 +22,8 @@ import warnings from test import PyMongoTestCase +import pytest + sys.path[0:0] = [""] from test import unittest @@ -30,6 +32,8 @@ from pymongo import MongoClient from pymongo.synchronous.auth_oidc import OIDCCallback +pytestmark = pytest.mark.auth + _IS_SYNC = True _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") diff --git a/test/test_bson.py b/test/test_bson.py index e601be4915..1616c513c2 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] from test import qcheck, unittest -from test.utils import ExceptionCatchingThread +from test.helpers import ExceptionCatchingTask import bson from bson import ( @@ -558,7 +558,7 @@ def test_unknown_type(self): decode(bs) except Exception as exc: self.assertTrue(isinstance(exc, InvalidBSON)) - self.assertTrue(part in str(exc)) + self.assertIn(part, str(exc)) else: self.fail("Failed to raise an exception.") @@ -809,6 +809,64 @@ def test_vector(self): dtype=BinaryVectorDtype.PACKED_BIT, ) # type: ignore[call-overload] + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + def test_binaryvector_repr(self): + """Tests of repr(BinaryVector)""" + + data = [1 / 127, -7 / 6] + one = BinaryVector(data, BinaryVectorDtype.FLOAT32) + self.assertEqual( + repr(one), f"BinaryVector(dtype=BinaryVectorDtype.FLOAT32, padding=0, data={data})" + ) + self.assertRepr(one) + + data = [127, 7] + two = BinaryVector(data, BinaryVectorDtype.INT8) + self.assertEqual( + repr(two), f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, 
data={data})" + ) + self.assertRepr(two) + + three = BinaryVector(data, BinaryVectorDtype.INT8, padding=0) + self.assertEqual( + repr(three), f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data={data})" + ) + self.assertRepr(three) + + four = BinaryVector(data, BinaryVectorDtype.PACKED_BIT, padding=3) + self.assertEqual( + repr(four), f"BinaryVector(dtype=BinaryVectorDtype.PACKED_BIT, padding=3, data={data})" + ) + self.assertRepr(four) + + zero = BinaryVector([], BinaryVectorDtype.INT8) + self.assertEqual( + repr(zero), "BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data=[])" + ) + self.assertRepr(zero) + + def test_binaryvector_equality(self): + """Tests of == __eq__""" + self.assertEqual( + BinaryVector([1.2, 1 - 1 / 3], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([1.2, 1 - 1.0 / 3.0], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertNotEqual( + BinaryVector([1.2, 1 - 1 / 3], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([1.2, 6.0 / 9.0], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertEqual( + BinaryVector([], BinaryVectorDtype.FLOAT32, 0), + BinaryVector([], BinaryVectorDtype.FLOAT32, 0), + ) + self.assertNotEqual( + BinaryVector([1], BinaryVectorDtype.INT8), BinaryVector([2], BinaryVectorDtype.INT8) + ) + def test_unicode_regex(self): """Tests we do not get a segfault for C extension on unicode RegExs. This had been happening. @@ -1075,7 +1133,7 @@ def target(i): my_int = type(f"MyInt_{i}_{j}", (int,), {}) bson.encode({"my_int": my_int()}) - threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] + threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() @@ -1114,7 +1172,7 @@ def __repr__(self): def test_doc_in_invalid_document_error_message_mapping(self): class MyMapping(abc.Mapping): - def keys(): + def keys(self): return ["t"] def __getitem__(self, name): diff --git a/test/test_bson_binary_vector.py b/test/test_bson_binary_vector.py index 00c82bbb65..9bfdcbfb9a 100644 --- a/test/test_bson_binary_vector.py +++ b/test/test_bson_binary_vector.py @@ -16,12 +16,11 @@ import binascii import codecs -import json import struct from pathlib import Path from test import unittest -from bson import decode, encode +from bson import decode, encode, json_util from bson.binary import Binary, BinaryVectorDtype _TEST_PATH = Path(__file__).parent / "bson_binary_vector" @@ -49,7 +48,7 @@ def create_test(case_spec): def run_test(self): for test_case in case_spec.get("tests", []): description = test_case["description"] - vector_exp = test_case["vector"] + vector_exp = test_case.get("vector", []) dtype_hex_exp = test_case["dtype_hex"] dtype_alias_exp = test_case.get("dtype_alias") padding_exp = test_case.get("padding", 0) @@ -62,9 +61,6 @@ def run_test(self): cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8")) decoded_doc = decode(cB_exp) binary_obs = decoded_doc[test_key] - # Handle special float cases like '-inf' - if dtype_exp in [BinaryVectorDtype.FLOAT32]: - vector_exp = [float(x) for x in vector_exp] # Test round-tripping canonical bson. 
self.assertEqual(encode(decoded_doc), cB_exp, description) @@ -76,9 +72,13 @@ def run_test(self): self.assertEqual( vector_obs.dtype, BinaryVectorDtype[dtype_alias_exp], description ) - self.assertEqual(vector_obs.data, vector_exp, description) - self.assertEqual(vector_obs.padding, padding_exp, description) - + if dtype_exp in [BinaryVectorDtype.FLOAT32]: + [ + self.assertAlmostEqual(vector_obs.data[i], vector_exp[i], delta=1e-5) + for i in range(len(vector_exp)) + ] + else: + self.assertEqual(vector_obs.data, vector_exp, description) # Test Binary Vector to BSON vector_exp = Binary.from_vector(vector_exp, dtype_exp, padding_exp) cB_obs = binascii.hexlify(encode({test_key: vector_exp})).decode().upper() @@ -86,7 +86,13 @@ def run_test(self): else: with self.assertRaises((struct.error, ValueError), msg=description): + # Tests Binary.from_vector Binary.from_vector(vector_exp, dtype_exp, padding_exp) + # Tests Binary.as_vector + cB_exp = binascii.unhexlify(canonical_bson_exp.encode("utf8")) + decoded_doc = decode(cB_exp) + binary_obs = decoded_doc[test_key] + binary_obs.as_vector() return run_test @@ -94,7 +100,7 @@ def run_test(self): def create_tests(): for filename in _TEST_PATH.glob("*.json"): with codecs.open(str(filename), encoding="utf-8") as test_file: - test_method = create_test(json.load(test_file)) + test_method = create_test(json_util.loads(test_file.read())) setattr(TestBSONBinaryVector, "test_" + filename.stem, test_method) diff --git a/test/test_bulk.py b/test/test_bulk.py index 6d29ff510a..8a863cc49b 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, remove_all_users, unittest -from test.utils import wait_until +from test.utils_shared import wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions @@ -959,7 +959,6 @@ def cause_wtimeout(self, requests, ordered): @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_ordered(self): - self.skipTest("Skipping until PYTHON-4865 is resolved.") details = None # Ensure we don't raise on wnote. diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 4ed21f55cf..6099829031 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, @@ -406,7 +406,14 @@ def test_change_operations(self): expected_update_description = {"updatedFields": {"new": 1}, "removedFields": ["foo"]} if client_context.version.at_least(4, 5, 0): expected_update_description["truncatedArrays"] = [] - self.assertEqual(expected_update_description, change["updateDescription"]) + self.assertEqual( + expected_update_description, + { + k: v + for k, v in change["updateDescription"].items() + if k in expected_update_description + }, + ) # Replace. 
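        # A replace_one is reported by the change stream as operationType
        # "replace".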
self.watched_collection().replace_one({"new": 1}, {"foo": "bar"}) change = change_stream.next() diff --git a/test/test_client.py b/test/test_client.py index 2a33077f5f..14da72a8bc 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -61,17 +61,19 @@ from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( + assertRaisesExactly, + get_pool, + wait_until, +) +from test.utils_shared import ( NTHREADS, CMAPListener, FunctionCallRecorder, - assertRaisesExactly, delay, - get_pool, gevent_monkey_patched, is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson @@ -100,6 +102,7 @@ NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, + WaitQueueTimeoutError, WriteConcernError, ) from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent @@ -502,13 +505,13 @@ def test_uri_option_precedence(self): def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.srv_resolver import _resolve + from pymongo.synchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.synchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.synchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) @@ -597,7 +600,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): MongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -619,7 +622,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ @@ -821,6 +824,58 @@ def test_init_disconnected_with_auth(self): with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() + @client_context.require_replica_set + @client_context.require_no_load_balancer + @client_context.require_tls + def test_init_disconnected_with_srv(self): + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # nodes returns an empty set if not connected + self.assertEqual(c.nodes, frozenset()) + # topology_description returns the initial seed description if not connected + topology_description = c.topology_description + self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown) + self.assertEqual( + { + ("test1.test.build.10gen.cc", None): ServerDescription( + ("test1.test.build.10gen.cc", None) + ) + }, + topology_description.server_descriptions(), + ) + + # address causes client to block until connected + self.assertIsNotNone(c.address) + # Initial seed topology and connected topology have the same ID + self.assertEqual( + c._topology._topology_id, topology_description._topology_settings._topology_id + ) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # primary causes client to block 
until connected + c.primary + self.assertIsNotNone(c._topology) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # secondaries causes client to block until connected + c.secondaries + self.assertIsNotNone(c._topology) + c.close() + + c = self.rs_or_single_client( + "mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True + ) + # arbiters causes client to block until connected + c.arbiters + self.assertIsNotNone(c._topology) + def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = self.rs_or_single_client(seed, connect=False) @@ -905,6 +960,15 @@ def test_repr(self): with eval(the_repr) as client_two: self.assertEqual(client_two, client) + def test_repr_srv_host(self): + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + client._connect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + client.close() + def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") @@ -1221,7 +1285,6 @@ def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addCleanup(timeout.close) no_timeout.pymongo_test.drop_collection("test") no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1270,13 +1333,21 @@ def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): - client = self.rs_or_single_client(waitQueueTimeoutMS=2000) - self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) + listener = CMAPListener() + client = self.rs_or_single_client( + waitQueueTimeoutMS=10, maxPoolSize=1, event_listeners=[listener] + ) + pool = get_pool(client) + self.assertEqual(pool.opts.wait_queue_timeout, 0.01) + with pool.checkout(): + with self.assertRaises(WaitQueueTimeoutError): + client.test.command("ping") + self.assertFalse(listener.events_by_type(monitoring.PoolClearedEvent)) def test_socketKeepAlive(self): pool = get_pool(self.client) with pool.checkout() as conn: - keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @no_type_check @@ -1747,6 +1818,29 @@ def stall_connect(*args, **kwargs): # Each ping command should not take more than 2 seconds self.assertLess(total, 2) + def test_background_connections_log_on_error(self): + with self.assertLogs("pymongo.client", level="ERROR") as cm: + client = self.rs_or_single_client(minPoolSize=1) + # Create a single connection in the pool. + client.admin.command("ping") + + # Cause new connections to fail. + pool = get_pool(client) + + def fail_connect(*args, **kwargs): + raise Exception("failed to connect") + + pool.connect = fail_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addCleanup(delattr, pool, "connect") + + pool.reset_without_pause() + + wait_until( + lambda: "failed to connect" in "".join(cm.output), "start creating connections" + ) + self.assertIn("MongoClient background task encountered an error", "".join(cm.output)) + @client_context.require_replica_set def test_direct_connection(self): # direct_connection=True should result in Single topology. 
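        # (It also pins all operations to the seed address; no other members
        # are discovered.)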
@@ -1781,20 +1875,20 @@ def server_description_count(): return i gc.collect() - with client_knobs(min_heartbeat_interval=0.003): + with client_knobs(min_heartbeat_interval=0.002): client = self.simple_client( - "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 + "invalid:27017", heartbeatFrequencyMS=2, serverSelectionTimeoutMS=200 ) initial_count = server_description_count() with self.assertRaises(ServerSelectionTimeoutError): client.test.test.find_one() gc.collect() final_count = server_description_count() + client.close() # If a bug like PYTHON-2433 is reintroduced then too many # ServerDescriptions will be kept alive and this test will fail: - # AssertionError: 19 != 46 within 15 delta (27 difference) - # On Python 3.11 we seem to get more of a delta. - self.assertAlmostEqual(initial_count, final_count, delta=20) + # AssertionError: 11 != 47 within 20 delta (36 difference) + self.assertAlmostEqual(initial_count, final_count, delta=30) @client_context.require_failCommand_fail_point def test_network_error_message(self): @@ -1834,28 +1928,37 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2399,7 +2502,7 @@ def test_reconnect(self): # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from - # AsyncMockPool.get_socket). + # MockPool.get_socket). 
with self.assertRaises(AutoReconnect): c.db.collection.find_one() diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index f8d92668ea..866b179c9e 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -25,7 +25,7 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch @@ -647,7 +647,6 @@ def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 internal_client = self.rs_or_single_client(timeoutMS=None) - self.addCleanup(internal_client.close) collection = internal_client.db["coll"] self.addCleanup(collection.drop) diff --git a/test/test_client_context.py b/test/test_client_context.py index e807ac5f5f..ef3633a8b0 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -47,20 +47,14 @@ def test_serverless(self): ) def test_enableTestCommands_is_disabled(self): - if not os.environ.get("PYMONGO_DISABLE_TEST_COMMANDS"): - raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set") + if not os.environ.get("DISABLE_TEST_COMMANDS"): + raise SkipTest("DISABLE_TEST_COMMANDS is not set") self.assertFalse( client_context.test_commands_enabled, - "enableTestCommands must be disabled when PYMONGO_DISABLE_TEST_COMMANDS is set.", + "enableTestCommands must be disabled when DISABLE_TEST_COMMANDS is set.", ) - def test_setdefaultencoding_worked(self): - if not os.environ.get("SETDEFAULTENCODING"): - raise SkipTest("SETDEFAULTENCODING is not set") - - self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) - def test_free_threading_is_enabled(self): if "free-threading build" not in sys.version: raise SkipTest("this test requires the Python free-threading build") diff --git a/test/test_collation.py b/test/test_collation.py index 06436f0638..5425551dc6 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.collation import ( diff --git a/test/test_collection.py b/test/test_collection.py index 8a862646eb..75c11383d0 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.utils import get_pool, is_mongos from typing import Any, Iterable, no_type_check from pymongo.synchronous.database import Database @@ -33,12 +34,10 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - get_pool, - is_mongos, wait_until, ) diff --git a/test/test_comment.py b/test/test_comment.py index 9f9bf98640..b6c17c14fe 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test import IntegrationTest, client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.operations import IndexModel diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 05411d17ba..1405824453 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -15,20 +15,20 @@ """Execute Transactions Spec tests.""" from __future__ 
import annotations +import asyncio import os import sys import time +from pathlib import Path +from test.utils import get_pool, get_pools sys.path[0:0] = [""] -from test import IntegrationTest, client_knobs, unittest +from test import IntegrationTest, client_context, client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import ( +from test.utils_shared import ( CMAPListener, camel_to_snake, - client_context, - get_pool, - get_pools, wait_until, ) from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator @@ -60,6 +60,8 @@ from pymongo.synchronous.pool import PoolState, _PoolClosedError from pymongo.topology_description import updated_topology_description +_IS_SYNC = True + OBJECT_TYPES = { # Event types. "ConnectionCheckedIn": ConnectionCheckedInEvent, @@ -81,7 +83,10 @@ class TestCMAP(IntegrationTest): # Location of JSON test specifications. - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") # Test operations: @@ -204,15 +209,10 @@ def check_error(self, actual, expected): self.check_object(actual, expected) self.assertIn(message, str(actual)) - def _set_fail_point(self, client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - client.admin.command(cmd) - def set_fail_point(self, command_args): if not client_context.supports_failCommand_fail_point: self.skipTest("failCommand fail point must be supported") - self._set_fail_point(self.client, command_args) + self.configure_fail_point(self.client, command_args) def run_scenario(self, scenario_def, test): """Run a CMAP spec test.""" @@ -258,7 +258,6 @@ def run_scenario(self, scenario_def, test): client._topology.open() else: client._get_topology() - self.addCleanup(client.close) self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -315,13 +314,11 @@ def cleanup(): # def test_1_client_connection_pool_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. if client_context.has_secondaries: @@ -337,14 +334,12 @@ def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" client = self.rs_or_single_client(uri) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. 
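        # The ping below checks out a pooled connection, emitting
        # ConnectionCreatedEvent and ConnectionReadyEvent.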
@@ -368,7 +363,6 @@ def test_4_subscribe_to_events(self): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) pool = get_pool(client) def mock_connect(*args, **kwargs): @@ -397,7 +391,6 @@ def test_5_check_out_fails_auth_error(self): client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) - self.addCleanup(client.close) # Attempt to create a new connection. with self.assertRaisesRegex(OperationFailure, "failed"): diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 9cac633301..d923a477b5 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.utils import ensure_all_connected sys.path[0:0] = [""] @@ -25,9 +26,8 @@ unittest, ) from test.helpers import repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, - ensure_all_connected, ) from bson import SON diff --git a/test/test_csot.py b/test/test_csot.py index c075a07d5a..5201156a1d 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -17,6 +17,7 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] @@ -27,8 +28,13 @@ from pymongo import _csot from pymongo.errors import PyMongoError +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "csot") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") # Generate unified tests. 
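# generate_test_classes builds one TestCase subclass per JSON spec file;
# updating globals() registers them for unittest discovery.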
globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_cursor.py b/test/test_cursor.py index 84e431f8cb..7b75f4ddc4 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, @@ -1801,6 +1801,7 @@ def test_monitoring(self): @client_context.require_version_min(5, 0, -1) @client_context.require_no_mongos + @client_context.require_sync def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) @@ -1810,7 +1811,7 @@ def test_exhaust_cursor_db_set(self): listener.reset() - result = c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list() + result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1)) self.assertEqual(len(result), 3) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 6771ea25f9..08e2a46f8f 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -23,10 +23,11 @@ from random import random from typing import Any, Tuple, Type, no_type_check +from gridfs.synchronous.grid_file import GridIn, GridOut + sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test import IntegrationTest, client_context, unittest from bson import ( _BUILT_IN_TYPES, @@ -50,10 +51,12 @@ from bson.errors import InvalidDocument from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument -from gridfs import GridIn, GridOut from pymongo.errors import DuplicateKeyError from pymongo.message import _CursorAddress from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.helpers import next + +_IS_SYNC = True class DecimalEncoder(TypeEncoder): @@ -707,7 +710,7 @@ def test_aggregate_w_custom_type_decoder(self): ] result = test.aggregate(pipeline) - res = list(result)[0] + res = (result.to_list())[0] self.assertEqual(res["_id"], "complete") self.assertIsInstance(res["total_qty"], UndecipherableInt64Type) self.assertEqual(res["total_qty"].value, 20) @@ -774,6 +777,7 @@ def test_grid_out_custom_opts(self): one.close() two = GridOut(db.fs, 5) + two.open() self.assertEqual("my_file", two.name) self.assertEqual("my_file", two.filename) @@ -970,7 +974,6 @@ def create_targets(self, *args, **kwargs): kwargs["type_registry"] = codec_options.type_registry kwargs["document_class"] = codec_options.document_class self.watched_target = self.rs_client(*args, **kwargs) - self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. self.input_target.insert_one({"data": "dummy"}) diff --git a/test/test_data_lake.py b/test/test_data_lake.py index a374db550e..d6d2007007 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -23,25 +23,21 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, UnitTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) -pytestmark = pytest.mark.data_lake +from pymongo.synchronous.helpers import next +_IS_SYNC = True -# Location of JSON test specifications. 
-_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data_lake") +pytestmark = pytest.mark.data_lake -class TestDataLakeMustConnect(unittest.TestCase): +class TestDataLakeMustConnect(UnitTest): def test_connected_to_data_lake(self): - data_lake = os.environ.get("TEST_DATA_LAKE") - if not data_lake: - self.skipTest("TEST_DATA_LAKE is not set") - self.assertTrue( client_context.is_data_lake and client_context.connected, "client context must be connected to data lake when DATA_LAKE is set. Failed attempts:\n{}".format( @@ -55,10 +51,9 @@ class TestDataLakeProse(IntegrationTest): TEST_DB = "test" TEST_COLLECTION = "driverdata" - @classmethod @client_context.require_data_lake - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() # Test killCursors def test_1(self): @@ -99,7 +94,10 @@ def test_3(self): # Location of JSON test specifications. -TEST_PATH = Path(__file__).parent / "data_lake/unified" +if _IS_SYNC: + TEST_PATH = Path(__file__).parent / "data_lake/unified" +else: + TEST_PATH = Path(__file__).parent.parent / "data_lake/unified" # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_database.py b/test/test_database.py index aad9089bd8..4c09b421cf 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_context, unittest from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, wait_until, @@ -425,6 +425,21 @@ def test_command_with_regex(self): for doc in result["cursor"]["firstBatch"]: self.assertTrue(isinstance(doc["r"], Regex)) + def test_command_bulkWrite(self): + # Ensure bulk write commands can be run directly via db.command(). 
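+        # The top-level bulkWrite command requires MongoDB 8.0+, hence the
+        # version gate; the per-collection insert/update/delete commands below
+        # run on all supported servers.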
+ if client_context.version.at_least(8, 0): + self.client.admin.command( + { + "bulkWrite": 1, + "nsInfo": [{"ns": self.db.test.full_name}], + "ops": [{"insert": 0, "document": {}}], + } + ) + self.db.command({"insert": "test", "documents": [{}]}) + self.db.command({"update": "test", "updates": [{"q": {}, "u": {"$set": {"x": 1}}}]}) + self.db.command({"delete": "test", "deletes": [{"q": {}, "limit": 1}]}) + self.db.test.drop() + def test_cursor_command(self): db = self.client.pymongo_test db.test.drop() diff --git a/test/test_default_exports.py b/test/test_default_exports.py index d9301d2223..adc3882a36 100644 --- a/test/test_default_exports.py +++ b/test/test_default_exports.py @@ -209,6 +209,19 @@ def test_pymongo_imports(self): ) from pymongo.write_concern import WriteConcern, validate_boolean + def test_pymongo_submodule_attributes(self): + import pymongo + + self.assertTrue(hasattr(pymongo, "uri_parser")) + self.assertTrue(pymongo.uri_parser) + self.assertTrue(pymongo.uri_parser.parse_uri) + self.assertTrue(pymongo.change_stream) + self.assertTrue(pymongo.client_session) + self.assertTrue(pymongo.collection) + self.assertTrue(pymongo.cursor) + self.assertTrue(pymongo.command_cursor) + self.assertTrue(pymongo.database) + def test_gridfs_imports(self): import gridfs from gridfs.errors import CorruptGridFile, FileExists, GridFSError, NoFile diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ce7a52f1a0..9d6c945707 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -15,30 +15,48 @@ """Test the topology module.""" from __future__ import annotations +import asyncio import os import socketserver import sys import threading +import time +from asyncio import StreamReader, StreamWriter +from pathlib import Path +from test.helpers import ConcurrentRunner + +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector +from pymongo.synchronous.pool import Connection sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + UnitTest, + client_context, + unittest, +) from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( + get_pool, +) +from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, - client_context, - get_pool, + barrier_wait, + create_barrier, server_name_to_type, wait_until, ) from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import MongoClient, common, monitoring +from pymongo import common, monitoring from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -52,11 +70,19 @@ from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext +from pymongo.synchronous.uri_parser import parse_uri from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.uri_parser import parse_uri + +_IS_SYNC = True # Location of JSON test specifications. 
-SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") +if _IS_SYNC: + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -128,7 +154,7 @@ def get_type(topology, hostname): return description.server_type -class TestAllScenarios(unittest.TestCase): +class TestAllScenarios(UnitTest): pass @@ -166,6 +192,9 @@ def check_outcome(self, topology, outcome): server_type_name(expected_server_type), server_type_name(actual_server_description.server_type), ) + expected_error = expected_server.get("error") + if expected_error: + self.assertIn(expected_error, str(actual_server_description.error)) self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) @@ -240,11 +269,11 @@ def create_tests(): create_tests() -class TestClusterTimeComparison(unittest.TestCase): +class TestClusterTimeComparison(PyMongoTestCase): def test_cluster_time_comparison(self): t = create_mock_topology("mongodb://host") - def send_cluster_time(time, inc, should_update): + def send_cluster_time(time, inc): old = t.max_cluster_time() new = {"clusterTime": Timestamp(time, inc)} got_hello( @@ -259,34 +288,33 @@ def send_cluster_time(time, inc, should_update): ) actual = t.max_cluster_time() - if should_update: - self.assertEqual(actual, new) - else: - self.assertEqual(actual, old) + # We never update $clusterTime from monitoring connections. + self.assertEqual(actual, old) - send_cluster_time(0, 1, True) - send_cluster_time(2, 2, True) - send_cluster_time(2, 1, False) - send_cluster_time(1, 3, False) - send_cluster_time(2, 3, True) + send_cluster_time(0, 1) + send_cluster_time(2, 2) + send_cluster_time(2, 1) + send_cluster_time(1, 3) + send_cluster_time(2, 3) class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = self.rs_or_single_client(minPoolSize=N_THREADS) - self.addCleanup(client.close) + if not _IS_SYNC and sys.version_info < (3, 11): + self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") + N_TASKS = 5 + barrier = create_barrier(N_TASKS) + client = self.rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. client.admin.command("ping") pool = get_pool(client) starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() + # Synchronize all tasks to ensure they use the same generation. 
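create_barrier and barrier_wait are shared helpers that hide the sync/async split; presumably they wrap threading.Barrier in this synchronous build and asyncio.Barrier (Python 3.11+, hence the version skip above) in the async one. A rough sketch under that assumption, not the actual helpers:

import asyncio
import threading

def create_barrier(n):
    return threading.Barrier(n) if _IS_SYNC else asyncio.Barrier(n)

def barrier_wait(barrier, timeout=None):
    barrier.wait(timeout=timeout)  # the async flavor would await barrier.wait()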
+ barrier_wait(barrier, timeout=30) raise AutoReconnect("mock Connection.command error") for conn in pool.conns: @@ -298,12 +326,12 @@ def insert_command(i): except AutoReconnect: pass - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: + tasks = [] + for i in range(N_TASKS): + tasks.append(ConcurrentRunner(target=insert_command, args=(i,))) + for t in tasks: t.start() - for t in threads: + for t in tasks: t.join() # Expect a single pool reset for the network error @@ -322,10 +350,9 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = self.single_client( + _ = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) - self.addCleanup(client.close) # Assert that ConnectionPoolReadyEvent occurs after the first # ServerHeartbeatSucceededEvent. listener.wait_for_event(monitoring.PoolReadyEvent, 1) @@ -348,6 +375,72 @@ def test_pool_unpause(self): listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) listener.wait_for_event(monitoring.PoolReadyEvent, 1) + @client_context.require_failCommand_appName + @client_context.require_test_commands + @client_context.require_async + def test_connection_close_does_not_block_other_operations(self): + listener = CMAPHeartbeatListener() + client = self.single_client( + appName="SDAMConnectionCloseTest", + event_listeners=[listener], + heartbeatFrequencyMS=500, + minPoolSize=10, + ) + server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST) + wait_until( + lambda: len(server._pool.conns) == 10, + "pool initialized with 10 connections", + ) + + client.db.test.insert_one({"x": 1}) + close_delay = 0.1 + latencies = [] + should_exit = [] + + def run_task(): + while True: + start_time = time.monotonic() + client.db.test.find_one({}) + elapsed = time.monotonic() - start_time + latencies.append(elapsed) + if should_exit: + break + time.sleep(0.001) + + task = ConcurrentRunner(target=run_task) + task.start() + original_close = Connection.close_conn + try: + # Artificially delay the close operation to simulate a slow close + def mock_close(self, reason): + time.sleep(close_delay) + original_close(self, reason) + + Connection.close_conn = mock_close + + fail_hello = { + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 91, + "appName": "SDAMConnectionCloseTest", + }, + } + with self.fail_point(fail_hello): + # Wait for the server heartbeat to fail + listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + # Wait until all idle connections are closed to simulate real-world conditions + listener.wait_for_event(monitoring.ConnectionClosedEvent, 10) + # Wait for one more find to complete after the pool has been reset, then shut down the task + n = len(latencies) + wait_until(lambda: len(latencies) >= n + 1, "run one more find") + should_exit.append(True) + task.join() + # No operation's latency should significantly exceed close_delay + self.assertLessEqual(max(latencies), close_delay * 5.0) + finally: + Connection.close_conn = original_close + class TestServerMonitoringMode(IntegrationTest): @client_context.require_no_serverless @@ -357,7 +450,6 @@ def setUp(self): def test_rtt_connection_is_enabled_stream(self): client = self.rs_or_single_client(serverMonitoringMode="stream") - self.addCleanup(client.close)
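The addCleanup(client.close) calls removed throughout this patch suggest the client factories (single_client, rs_or_single_client, simple_client) now register cleanup themselves; roughly, assuming a hypothetical _make_client helper:

def rs_or_single_client(self, **kwargs):
    client = self._make_client(**kwargs)  # hypothetical internal helper
    self.addCleanup(client.close)         # lifetime managed by the test base class
    return client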
client.admin.command("ping") def predicate(): @@ -366,18 +458,26 @@ def predicate(): if not monitor._stream: return False if client_context.version >= (4, 4): - if monitor._rtt_monitor._executor._thread is None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False else: - if monitor._rtt_monitor._executor._thread is not None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False return True wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): client = self.rs_or_single_client(serverMonitoringMode="poll") - self.addCleanup(client.close) + self.assert_rtt_connection_is_disabled(client) def test_rtt_connection_is_disabled_auto(self): @@ -391,7 +491,6 @@ def test_rtt_connection_is_disabled_auto(self): for env in envs: with patch.dict("os.environ", env): client = self.rs_or_single_client(serverMonitoringMode="auto") - self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) def assert_rtt_connection_is_disabled(self, client): @@ -399,7 +498,10 @@ def assert_rtt_connection_is_disabled(self, client): for _, server in client._topology._servers.items(): monitor = server._monitor self.assertFalse(monitor._stream) - self.assertIsNone(monitor._rtt_monitor._executor._thread) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) class MockTCPHandler(socketserver.BaseRequestHandler): @@ -422,16 +524,46 @@ class TestHeartbeatStartOrdering(PyMongoTestCase): def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) - server = TCPServer(("localhost", 9999), MockTCPHandler) - server.events = events - server_thread = threading.Thread(target=server.handle_request_and_shutdown) - server_thread.start() - _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server_thread.join() - listener.wait_for_event(ServerHeartbeatStartedEvent, 1) - listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + if _IS_SYNC: + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown) + server_thread.start() + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + else: + + def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") + if (reader.read(1024)).strip(): + events.append("client hello received") + writer.close() + writer.wait_closed() + + server = asyncio.start_server(handle_client, "localhost", 9999) + server.events = events + server.start_serving() + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + _c._connect() + + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + server.wait_closed() + _c.close() self.assertEqual( events, diff --git a/test/test_dns.py b/test/test_dns.py index f2185efb1b..8f88562e3f 100644 --- a/test/test_dns.py 
+++ b/test/test_dns.py @@ -18,22 +18,37 @@ import glob import json import os +import pathlib import sys sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, client_context, unittest -from test.utils import wait_until +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) +from test.utils_shared import wait_until +from unittest.mock import MagicMock, patch from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser_shared import split_hosts + +_IS_SYNC = True class TestDNSRepl(PyMongoTestCase): - TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" - ) + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set" + ) load_balanced = False @client_context.require_replica_set @@ -42,9 +57,14 @@ def setUp(self): class TestDNSLoadBalanced(PyMongoTestCase): - TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" - ) + if _IS_SYNC: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced" + ) + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced" + ) load_balanced = True @client_context.require_load_balancer @@ -53,7 +73,12 @@ def setUp(self): class TestDNSSharded(PyMongoTestCase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") + if _IS_SYNC: + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded") + else: + TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded" + ) load_balanced = False @client_context.require_mongos @@ -119,7 +144,9 @@ def run_test(self): # tests. copts["tlsAllowInvalidHostnames"] = True - client = PyMongoTestCase.unmanaged_simple_client(uri, **copts) + client = self.simple_client(uri, **copts) + if client._options.connect: + client._connect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: @@ -132,7 +159,6 @@ def run_test(self): client.admin.command("ping") # XXX: we should block until SRV poller runs at least once # and re-run these assertions. 
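parse_uri has moved to pymongo.synchronous.uri_parser (with an async counterpart implied by the package layout), while purely lexical helpers such as split_hosts stay in the shared pymongo.uri_parser_shared. A minimal usage sketch, assuming an SRV-resolvable host:

from pymongo.synchronous.uri_parser import parse_uri

res = parse_uri("mongodb+srv://test1.test.build.10gen.cc/")
res["nodelist"]  # (host, port) pairs resolved from the SRV record
res["options"]   # options merged from the accompanying TXT record, if any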
- client.close() else: try: parse_uri(uri) @@ -159,38 +185,122 @@ def create_tests(cls): class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb is not", - self.simple_client, - "mongodb+srv://mongodb", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb.com is not", - self.simple_client, - "mongodb+srv://mongodb.com", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://127.0.0.1", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://[::1]", - ) + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + client._connect() class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") - self.addCleanup(client.close) + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) +class TestInitialDnsSeedlistDiscovery(PyMongoTestCase): + """ + Initial DNS Seedlist Discovery prose tests + https://github.com/mongodb/specifications/blob/0a7a8b5/source/initial-dns-seedlist-discovery/tests/README.md#prose-tests + """ + + def run_initial_dns_seedlist_discovery_prose_tests(self, test_cases): + for case in test_cases: + with patch("dns.resolver.resolve") as mock_resolver: + + def mock_resolve(query, record_type, *args, **kwargs): + mock_srv = MagicMock() + mock_srv.target.to_text.return_value = case["mock_target"] + return [mock_srv] + + mock_resolver.side_effect = mock_resolve + domain = case["query"].split("._tcp.")[1] + connection_string = f"mongodb+srv://{domain}" + if "expected_error" not in case: + parse_uri(connection_string) + else: + try: + parse_uri(connection_string) + except ConfigurationError as e: + self.assertIn(case["expected_error"], str(e)) + else: + self.fail(f"ConfigurationError was not raised for query: {case['query']}") + + def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self): + with patch("dns.resolver.resolve"): + parse_uri("mongodb+srv://localhost/") + parse_uri("mongodb+srv://mongo.local/") + + def test_2_throw_when_return_address_does_not_end_with_srv_domain(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost.mongodb", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.evil.com", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongo.local", + "mock_target": "test_1.evil.com", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + def test_3_throw_when_return_address_is_identical_to_srv_hostname(self): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "mongo.local", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + def 
test_4_throw_when_return_address_does_not_contain_dot_separating_shared_part_of_domain( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.localhost", + "mock_target": "test_1.cluster_1localhost", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.mongo.local", + "mock_target": "test_1.my_hostmongo.local", + "expected_error": "Invalid SRV host", + }, + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "cluster.testmongodb.com", + "expected_error": "Invalid SRV host", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical( + self + ): + test_cases = [ + { + "query": "_mongodb._tcp.blogs.mongodb.com", + "mock_target": "blogs.mongodb.com", + }, + ] + self.run_initial_dns_seedlist_discovery_prose_tests(test_cases) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_encryption.py b/test/test_encryption.py index 9224310144..4b055b68d3 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -41,6 +41,7 @@ from pymongo.daemon import _spawn_daemon from pymongo.synchronous.collection import Collection from pymongo.synchronous.helpers import next +from pymongo.uri_parser_shared import _parse_kms_tls_options try: from pymongo.pyopenssl_context import IS_PYOPENSSL @@ -63,7 +64,7 @@ ) from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, @@ -73,7 +74,7 @@ ) from test.utils_spec_runner import SpecRunner -from bson import DatetimeMS, Decimal128, encode, json_util +from bson import BSON, DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.errors import BSONError @@ -91,6 +92,7 @@ EncryptionError, InvalidOperation, OperationFailure, + PyMongoError, ServerSelectionTimeoutError, WriteError, ) @@ -140,7 +142,7 @@ def test_init(self): self.assertEqual(opts._mongocryptd_bypass_spawn, False) self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd") self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"]) - self.assertEqual(opts._kms_ssl_contexts, {}) + self.assertEqual(opts._kms_tls_options, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init_spawn_args(self): @@ -166,28 +168,36 @@ def test_init_spawn_args(self): @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init_kms_tls_options(self): # Error cases: + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1}) with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'): - AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1}) + MongoClient(auto_encryption_opts=opts) + tls_opts: Any for tls_opts in [ {"kmip": {"tls": True, "tlsInsecure": True}}, {"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}}, {"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}}, ]: + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"): - opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) + MongoClient(auto_encryption_opts=opts) + opts = AutoEncryptionOpts( + {}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}} + ) with 
self.assertRaises(FileNotFoundError): - AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}) + MongoClient(auto_encryption_opts=opts) # Success cases: tls_opts: Any for tls_opts in [None, {}]: opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) - self.assertEqual(opts._kms_ssl_contexts, {}) + kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + self.assertEqual(kms_tls_contexts, {}) opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}}) - ctx = opts._kms_ssl_contexts["kmip"] + _kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + ctx = _kms_ssl_contexts["kmip"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) - ctx = opts._kms_ssl_contexts["aws"] + ctx = _kms_ssl_contexts["aws"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) opts = AutoEncryptionOpts( @@ -195,7 +205,8 @@ def test_init_kms_tls_options(self): "k.d", kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}, ) - ctx = opts._kms_ssl_contexts["kmip"] + _kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC) + ctx = _kms_ssl_contexts["kmip"] self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) @@ -2216,7 +2227,7 @@ def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] + ctx = encryption._io_callbacks._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") self.assertFalse(ctx.check_ocsp_endpoint) @@ -2403,6 +2414,310 @@ def test_05_roundtrip_encrypted_unindexed(self): self.assertEqual(decrypted, val) +# https://github.com/mongodb/specifications/blob/527e22d5090ec48bf1e144c45fc831de0f1935f6/source/client-side-encryption/tests/README.md#25-test-lookup +class TestLookupProse(EncryptionIntegrationTest): + @client_context.require_no_standalone + @client_context.require_version_min(7, 0, -1) + def setUp(self): + super().setUp() + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + encrypted_client.drop_database("db") + + key_doc = json_data("etc", "data", "lookup", "key-doc.json") + create_key_vault(encrypted_client.db.keyvault, key_doc) + self.addCleanup(client_context.client.drop_database, "db") + + encrypted_client.db.create_collection( + "csfle", + validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle.json")}, + ) + encrypted_client.db.create_collection( + "csfle2", + validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle2.json")}, + ) + encrypted_client.db.create_collection( + "qe", encryptedFields=json_data("etc", "data", "lookup", "schema-qe.json") + ) + encrypted_client.db.create_collection( + "qe2", encryptedFields=json_data("etc", "data", "lookup", "schema-qe2.json") + ) + encrypted_client.db.create_collection("no_schema") + encrypted_client.db.create_collection("no_schema2") + + unencrypted_client = self.rs_or_single_client() + + encrypted_client.db.csfle.insert_one({"csfle": "csfle"}) + doc = unencrypted_client.db.csfle.find_one() + self.assertTrue(isinstance(doc["csfle"], 
Binary)) + encrypted_client.db.csfle2.insert_one({"csfle2": "csfle2"}) + doc = unencrypted_client.db.csfle2.find_one() + self.assertTrue(isinstance(doc["csfle2"], Binary)) + encrypted_client.db.qe.insert_one({"qe": "qe"}) + doc = unencrypted_client.db.qe.find_one() + self.assertTrue(isinstance(doc["qe"], Binary)) + encrypted_client.db.qe2.insert_one({"qe2": "qe2"}) + doc = unencrypted_client.db.qe2.find_one() + self.assertTrue(isinstance(doc["qe2"], Binary)) + encrypted_client.db.no_schema.insert_one({"no_schema": "no_schema"}) + encrypted_client.db.no_schema2.insert_one({"no_schema2": "no_schema2"}) + + encrypted_client.close() + unencrypted_client.close() + + @client_context.require_version_min(8, 1, -1) + def test_1_csfle_joins_no_schema(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"csfle": "csfle", "matched": [{"no_schema": "no_schema"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_2_qe_joins_no_schema(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.qe.aggregate( + [ + {"$match": {"qe": "qe"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ] + ) + ) + self.assertEqual(doc, {"qe": "qe", "matched": [{"no_schema": "no_schema"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_3_no_schema_joins_csfle(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "csfle", + "as": "matched", + "pipeline": [{"$match": {"csfle": "csfle"}}, {"$project": {"_id": 0}}], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"csfle": "csfle"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_4_no_schema_joins_qe(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "qe", + "as": "matched", + "pipeline": [ + {"$match": {"qe": "qe"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"qe": "qe"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_5_csfle_joins_csfle2(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + 
kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "csfle2", + "as": "matched", + "pipeline": [ + {"$match": {"csfle2": "csfle2"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"csfle": "csfle", "matched": [{"csfle2": "csfle2"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_6_qe_joins_qe2(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.qe.aggregate( + [ + {"$match": {"qe": "qe"}}, + { + "$lookup": { + "from": "qe2", + "as": "matched", + "pipeline": [ + {"$match": {"qe2": "qe2"}}, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ], + } + }, + {"$project": {"_id": 0, "__safeContent__": 0}}, + ] + ) + ) + self.assertEqual(doc, {"qe": "qe", "matched": [{"qe2": "qe2"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_7_no_schema_joins_no_schema2(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + doc = next( + encrypted_client.db.no_schema.aggregate( + [ + {"$match": {"no_schema": "no_schema"}}, + { + "$lookup": { + "from": "no_schema2", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema2": "no_schema2"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertEqual(doc, {"no_schema": "no_schema", "matched": [{"no_schema2": "no_schema2"}]}) + + @client_context.require_version_min(8, 1, -1) + def test_8_csfle_joins_qe(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + with self.assertRaises(PyMongoError) as exc: + _ = next( + encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "qe"}}, + { + "$lookup": { + "from": "qe", + "as": "matched", + "pipeline": [{"$match": {"qe": "qe"}}, {"$project": {"_id": 0}}], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertIn("not supported", str(exc.exception)) + + @client_context.require_version_max(8, 1, -1) + def test_9_error(self): + encrypted_client = self.rs_or_single_client( + auto_encryption_opts=AutoEncryptionOpts( + key_vault_namespace="db.keyvault", + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + ) + ) + with self.assertRaises(PyMongoError) as exc: + _ = next( + encrypted_client.db.csfle.aggregate( + [ + {"$match": {"csfle": "csfle"}}, + { + "$lookup": { + "from": "no_schema", + "as": "matched", + "pipeline": [ + {"$match": {"no_schema": "no_schema"}}, + {"$project": {"_id": 0}}, + ], + } + }, + {"$project": {"_id": 0}}, + ] + ) + ) + self.assertIn("Upgrade", str(exc.exception)) + + # https://github.com/mongodb/specifications/blob/072601/source/client-side-encryption/tests/README.md#rewrap class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest): MASTER_KEYS: Mapping[str, Mapping[str, Any]] = { @@ -2964,9 +3279,10 @@ def test_02_no_fields(self): ) def test_03_invalid_keyid(self): + # checkAuthForCreateCollection can be removed when SERVER-102101 is fixed.
with self.assertRaisesRegex( EncryptedCollectionError, - "create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", + "(create|checkAuthForCreateCollection).encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData", ): self.client_encryption.create_encrypted_collection( database=self.db, diff --git a/test/test_examples.py b/test/test_examples.py index 7f98226e7a..28fe1beaff 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -15,22 +15,29 @@ """MongoDB documentation examples in Python.""" from __future__ import annotations +import asyncio import datetime +import functools import sys import threading +import time +from test.helpers import ConcurrentRunner sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import wait_until +from test.utils_shared import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_api import ServerApi +from pymongo.synchronous.helpers import next from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestSampleShellCommands(IntegrationTest): def setUp(self): @@ -62,7 +69,7 @@ def test_first_three_examples(self): cursor = db.inventory.find({"item": "canvas"}) # End Example 2 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 3 db.inventory.insert_many( @@ -137,31 +144,31 @@ def test_query_top_level_fields(self): cursor = db.inventory.find({}) # End Example 7 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 9 cursor = db.inventory.find({"status": "D"}) # End Example 9 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) # Start Example 10 cursor = db.inventory.find({"status": {"$in": ["A", "D"]}}) # End Example 10 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 11 cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}}) # End Example 11 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 12 cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) # End Example 12 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 13 cursor = db.inventory.find( @@ -169,7 +176,7 @@ def test_query_top_level_fields(self): ) # End Example 13 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) def test_query_embedded_documents(self): db = self.db @@ -219,31 +226,31 @@ def test_query_embedded_documents(self): cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) # End Example 15 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 16 cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) # End Example 16 - self.assertEqual(len(list(cursor)), 0) + self.assertEqual(len(cursor.to_list()), 0) # Start Example 17 cursor = db.inventory.find({"size.uom": "in"}) # End Example 17 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) # Start Example 18 cursor = db.inventory.find({"size.h": {"$lt": 15}}) # End Example 18 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 19 cursor = 
db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) # End Example 19 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_query_arrays(self): db = self.db @@ -269,49 +276,49 @@ def test_query_arrays(self): cursor = db.inventory.find({"tags": ["red", "blank"]}) # End Example 21 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 22 cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}}) # End Example 22 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 23 cursor = db.inventory.find({"tags": "red"}) # End Example 23 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 24 cursor = db.inventory.find({"dim_cm": {"$gt": 25}}) # End Example 24 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 25 cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}}) # End Example 25 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 26 cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) # End Example 26 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 27 cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}}) # End Example 27 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 28 cursor = db.inventory.find({"tags": {"$size": 3}}) # End Example 28 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_query_array_of_documents(self): db = self.db @@ -360,49 +367,49 @@ def test_query_array_of_documents(self): cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])}) # End Example 30 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 31 cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])}) # End Example 31 - self.assertEqual(len(list(cursor)), 0) + self.assertEqual(len(cursor.to_list()), 0) # Start Example 32 cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}}) # End Example 32 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 33 cursor = db.inventory.find({"instock.qty": {"$lte": 20}}) # End Example 33 - self.assertEqual(len(list(cursor)), 5) + self.assertEqual(len(cursor.to_list()), 5) # Start Example 34 cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) # End Example 34 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 35 cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) # End Example 35 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 36 cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}}) # End Example 36 - self.assertEqual(len(list(cursor)), 4) + self.assertEqual(len(cursor.to_list()), 4) # Start Example 37 cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"}) # End Example 37 - self.assertEqual(len(list(cursor)), 2) + self.assertEqual(len(cursor.to_list()), 2) def test_query_null(self): db = self.db @@ -415,19 +422,19 @@ def test_query_null(self): cursor = db.inventory.find({"item": None}) # End Example 39 - self.assertEqual(len(list(cursor)), 2) + 
self.assertEqual(len(cursor.to_list()), 2) # Start Example 40 cursor = db.inventory.find({"item": {"$type": 10}}) # End Example 40 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) # Start Example 41 cursor = db.inventory.find({"item": {"$exists": False}}) # End Example 41 - self.assertEqual(len(list(cursor)), 1) + self.assertEqual(len(cursor.to_list()), 1) def test_projection(self): db = self.db @@ -473,7 +480,7 @@ def test_projection(self): cursor = db.inventory.find({"status": "A"}) # End Example 43 - self.assertEqual(len(list(cursor)), 3) + self.assertEqual(len(cursor.to_list()), 3) # Start Example 44 cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1}) @@ -746,8 +753,9 @@ def insert_docs(): while not done: db.inventory.insert_one({"username": "alice"}) db.inventory.delete_one({"username": "alice"}) + time.sleep(0.005) - t = threading.Thread(target=insert_docs) + t = ConcurrentRunner(target=insert_docs) t.start() try: @@ -1153,12 +1161,7 @@ def callback(session): # Step 2: Start a client session. with client.start_session() as session: # Step 3: Use with_transaction to start a transaction, execute the callback, and commit (or abort on error). - session.with_transaction( - callback, - read_concern=ReadConcern("local"), - write_concern=wc_majority, - read_preference=ReadPreference.PRIMARY, - ) + session.with_transaction(callback) # End Transactions withTxn API Example 1 @@ -1347,20 +1350,37 @@ def test_snapshot_query(self): db.drop_collection("dogs") db.cats.insert_one({"name": "Whiskers", "color": "white", "age": 10, "adoptable": True}) db.dogs.insert_one({"name": "Pebbles", "color": "Brown", "age": 10, "adoptable": True}) - wait_until(lambda: self.check_for_snapshot(db.cats), "success") - wait_until(lambda: self.check_for_snapshot(db.dogs), "success") + + def predicate_one(): + return self.check_for_snapshot(db.cats) + + def predicate_two(): + return self.check_for_snapshot(db.dogs) + + wait_until(predicate_two, "success") + wait_until(predicate_one, "success") # Start Snapshot Query Example 1 db = client.pets with client.start_session(snapshot=True) as s: - adoptablePetsCount = db.cats.aggregate( - [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], session=s - ).next()["adoptableCatsCount"] - - adoptablePetsCount += db.dogs.aggregate( - [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], session=s - ).next()["adoptableDogsCount"] + adoptablePetsCount = ( + ( + db.cats.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], + session=s, + ) + ).next() + )["adoptableCatsCount"] + + adoptablePetsCount += ( + ( + db.dogs.aggregate( + [{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], + session=s, + ) + ).next() + )["adoptableDogsCount"] print(adoptablePetsCount) @@ -1371,33 +1391,41 @@ def test_snapshot_query(self): saleDate = datetime.datetime.now() db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate}) - wait_until(lambda: self.check_for_snapshot(db.sales), "success") + + def predicate_three(): + return self.check_for_snapshot(db.sales) + + wait_until(predicate_three, "success") # Start Snapshot Query Example 2 db = client.retail with client.start_session(snapshot=True) as s: - db.sales.aggregate( - [ - { - "$match": { - "$expr": { - "$gt": [ - "$saleDate", - { - "$dateSubtract": { - "startDate": "$$NOW", - "unit": "day", - "amount": 1, - } - }, - ] - } - } - }, - {"$count": "totalDailySales"}, - ], - session=s, - 
).next()["totalDailySales"] + _ = ( + ( + db.sales.aggregate( + [ + { + "$match": { + "$expr": { + "$gt": [ + "$saleDate", + { + "$dateSubtract": { + "startDate": "$$NOW", + "unit": "day", + "amount": 1, + } + }, + ] + } + } + }, + {"$count": "totalDailySales"}, + ], + session=s, + ) + ).next() + )["totalDailySales"] # End Snapshot Query Example 2 diff --git a/test/test_fork.py b/test/test_fork.py index 1a89159435..fe88d778d2 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest -from test.utils import is_greenthread_patched +from test.utils_shared import is_greenthread_patched from bson.objectid import ObjectId diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 6534bc11bf..0baeb5ae19 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.errors import NoFile diff --git a/test/test_gridfs.py b/test/test_gridfs.py index ab8950250b..75342ee437 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -16,17 +16,20 @@ """Tests for the gridfs package.""" from __future__ import annotations +import asyncio import datetime import sys import threading import time from io import BytesIO +from test.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary @@ -41,10 +44,12 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +_IS_SYNC = True -class JustWrite(threading.Thread): + +class JustWrite(ConcurrentRunner): def __init__(self, fs, n): - threading.Thread.__init__(self) + super().__init__() self.fs = fs self.n = n self.daemon = True @@ -56,9 +61,9 @@ def run(self): file.close() -class JustRead(threading.Thread): +class JustRead(ConcurrentRunner): def __init__(self, fs, n, results): - threading.Thread.__init__(self) + super().__init__() self.fs = fs self.n = n self.results = results @@ -98,19 +103,21 @@ def setUp(self): def test_basic(self): oid = self.fs.put(b"hello world") - self.assertEqual(b"hello world", self.fs.get(oid).read()) + self.assertEqual(b"hello world", (self.fs.get(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) - self.assertRaises(NoFile, self.fs.get, oid) + with self.assertRaises(NoFile): + self.fs.get(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) - self.assertRaises(NoFile, self.fs.get, "foo") + with self.assertRaises(NoFile): + self.fs.get("foo") oid = self.fs.put(b"hello world", _id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.fs.get("foo").read()) + self.assertEqual(b"hello world", (self.fs.get("foo")).read()) def test_multi_chunk_delete(self): self.db.fs.drop() @@ -142,7 +149,7 @@ def test_list(self): def test_empty_file(self): oid = self.fs.put(b"") - self.assertEqual(b"", self.fs.get(oid).read()) + self.assertEqual(b"", (self.fs.get(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -159,10 +166,12 @@ def 
test_corrupt_chunk(self): self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.read) + with self.assertRaises(CorruptGridFile): + out.read() out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.readline) + with self.assertRaises(CorruptGridFile): + out.readline() finally: self.fs.delete(files_id) @@ -177,31 +186,33 @@ def test_put_ensures_index(self): self.assertTrue( any( info.get("key") == [("files_id", 1), ("n", 1)] - for info in chunks.index_information().values() + for info in (chunks.index_information()).values() ) ) self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) def test_alt_collection(self): oid = self.alt.put(b"hello world") - self.assertEqual(b"hello world", self.alt.get(oid).read()) + self.assertEqual(b"hello world", (self.alt.get(oid)).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) - self.assertRaises(NoFile, self.alt.get, oid) + with self.assertRaises(NoFile): + self.alt.get(oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) - self.assertRaises(NoFile, self.alt.get, "foo") + with self.assertRaises(NoFile): + self.alt.get("foo") oid = self.alt.put(b"hello world", _id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.alt.get("foo").read()) + self.assertEqual(b"hello world", (self.alt.get("foo")).read()) self.alt.put(b"", filename="mike") self.alt.put(b"foo", filename="test") @@ -212,23 +223,23 @@ def test_alt_collection(self): def test_threaded_reads(self): self.fs.put(b"hello", _id="test") - threads = [] + tasks = [] results: list = [] for i in range(10): - threads.append(JustRead(self.fs, 10, results)) - threads[i].start() + tasks.append(JustRead(self.fs, 10, results)) + tasks[i].start() - joinall(threads) + joinall(tasks) self.assertEqual(100 * [b"hello"], results) def test_threaded_writes(self): - threads = [] + tasks = [] for i in range(10): - threads.append(JustWrite(self.fs, 10)) - threads[i].start() + tasks.append(JustWrite(self.fs, 10)) + tasks[i].start() - joinall(threads) + joinall(tasks) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") @@ -246,34 +257,37 @@ def test_get_last_version(self): two = two._id three = self.fs.put(b"baz", filename="test") - self.assertEqual(b"baz", self.fs.get_last_version("test").read()) + self.assertEqual(b"baz", (self.fs.get_last_version("test")).read()) self.fs.delete(three) - self.assertEqual(b"bar", self.fs.get_last_version("test").read()) + self.assertEqual(b"bar", (self.fs.get_last_version("test")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version("test").read()) + self.assertEqual(b"foo", (self.fs.get_last_version("test")).read()) self.fs.delete(one) - self.assertRaises(NoFile, self.fs.get_last_version, "test") + with self.assertRaises(NoFile): + self.fs.get_last_version("test") def test_get_last_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author") - self.assertEqual(b"bar", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"bar", (self.fs.get_last_version(author="author")).read()) 
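The doubled parentheses in patterns like (self.fs.get_last_version(...)).read() are evidently an artifact of the async-to-sync code generator: the async source awaits the call, and stripping the await keeps the sync file line-for-line comparable. For example, an async source line such as

data = (await fs.get(oid)).read()

is emitted in this file as

data = (fs.get(oid)).read()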
self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"foo", (self.fs.get_last_version(author="author")).read()) self.fs.delete(one) one = self.fs.put(b"foo", filename="test", author="author1") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author2") - self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read()) - self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read()) - self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read()) + self.assertEqual(b"foo", (self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(filename="test")).read()) - self.assertRaises(NoFile, self.fs.get_last_version, author="author3") - self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1") + with self.assertRaises(NoFile): + self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + self.fs.get_last_version(filename="nottest", author="author1") self.fs.delete(one) self.fs.delete(two) @@ -286,16 +300,18 @@ def test_get_version(self): self.fs.put(b"baz", filename="test") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.get_version("test", 0).read()) - self.assertEqual(b"bar", self.fs.get_version("test", 1).read()) - self.assertEqual(b"baz", self.fs.get_version("test", 2).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", (self.fs.get_version("test", 2)).read()) - self.assertEqual(b"baz", self.fs.get_version("test", -1).read()) - self.assertEqual(b"bar", self.fs.get_version("test", -2).read()) - self.assertEqual(b"foo", self.fs.get_version("test", -3).read()) + self.assertEqual(b"baz", (self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", -3)).read()) - self.assertRaises(NoFile, self.fs.get_version, "test", 3) - self.assertRaises(NoFile, self.fs.get_version, "test", -4) + with self.assertRaises(NoFile): + self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + self.fs.get_version("test", -4) def test_get_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author1") @@ -305,25 +321,32 @@ def test_get_version_with_metadata(self): three = self.fs.put(b"baz", filename="test", author="author2") self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=-2)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=-1)).read(), ) self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=0).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=0)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=1)).read(), ) self.assertEqual( - b"baz", self.fs.get_version(filename="test", author="author2", version=0).read() + b"baz", + (self.fs.get_version(filename="test", author="author2", version=0)).read(), ) - 
self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read()) - self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=-1)).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=2)).read()) - self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3") - self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2) + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author1", version=2) self.fs.delete(one) self.fs.delete(two) @@ -332,11 +355,12 @@ def test_get_version_with_metadata(self): def test_put_filelike(self): oid = self.fs.put(BytesIO(b"hello world"), chunk_size=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", self.fs.get(oid).read()) + self.assertEqual(b"hello world", (self.fs.get(oid)).read()) def test_file_exists(self): oid = self.fs.put(b"hello") - self.assertRaises(FileExists, self.fs.put, b"world", _id=oid) + with self.assertRaises(FileExists): + self.fs.put(b"world", _id=oid) one = self.fs.new_file(_id=123) one.write(b"some content") @@ -345,15 +369,17 @@ def test_file_exists(self): # Attempt to upload a file with more chunks to the same _id. with patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): two = self.fs.new_file(_id=123) - self.assertRaises(FileExists, two.write, b"x" * DEFAULT_CHUNK_SIZE * 3) + with self.assertRaises(FileExists): + two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) # Original file is still readable (no extra chunks were uploaded). - self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") two = self.fs.new_file(_id=123) two.write(b"some content") - self.assertRaises(FileExists, two.close) + with self.assertRaises(FileExists): + two.close() # Original file is still readable. 
- self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") def test_exists(self): oid = self.fs.put(b"hello") @@ -381,15 +407,16 @@ def test_exists(self): self.assertFalse(self.fs.exists({"foo": {"$gt": 12}})) def test_put_unicode(self): - self.assertRaises(TypeError, self.fs.put, "hello") + with self.assertRaises(TypeError): + self.fs.put("hello") oid = self.fs.put("hello", encoding="utf-8") - self.assertEqual(b"hello", self.fs.get(oid).read()) - self.assertEqual("utf-8", self.fs.get(oid).encoding) + self.assertEqual(b"hello", (self.fs.get(oid)).read()) + self.assertEqual("utf-8", (self.fs.get(oid)).encoding) oid = self.fs.put("aé", encoding="iso-8859-1") - self.assertEqual("aé".encode("iso-8859-1"), self.fs.get(oid).read()) - self.assertEqual("iso-8859-1", self.fs.get(oid).encoding) + self.assertEqual("aé".encode("iso-8859-1"), (self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (self.fs.get(oid)).encoding) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -411,11 +438,13 @@ def test_gridfs_lazy_connect(self): client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) - self.assertRaises(ServerSelectionTimeoutError, gfs.list) + with self.assertRaises(ServerSelectionTimeoutError): + gfs.list() fs = gridfs.GridFS(db) f = fs.new_file() - self.assertRaises(ServerSelectionTimeoutError, f.close) + with self.assertRaises(ServerSelectionTimeoutError): + f.close() def test_gridfs_find(self): self.fs.put(b"test2", filename="two") @@ -429,14 +458,15 @@ def test_gridfs_find(self): self.assertEqual(3, files.count_documents({"filename": "two"})) self.assertEqual(4, files.count_documents({})) cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) cursor.rewind() - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test2+", gout.read()) - self.assertRaises(StopIteration, cursor.__next__) + with self.assertRaises(StopIteration): + cursor.__next__() cursor.rewind() items = cursor.to_list() self.assertEqual(len(items), 2) @@ -484,12 +514,12 @@ def test_grid_in_non_int_chunksize(self): self.fs.put(data, filename="f") self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.get_version("f").read()) + self.assertEqual(data, (self.fs.get_version("f")).read()) def test_unacknowledged(self): # w=0 is prohibited. 
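GridFS rejects w=0 clients up front: it must read back file and chunk state (and verify indexes) as part of normal operation, which is meaningless without acknowledged writes, so the constructor fails fast. A minimal reproduction, mirroring the assertion below:

import gridfs
from pymongo import MongoClient

db = MongoClient(w=0).pymongo_test
gridfs.GridFS(db)  # raises ConfigurationError because w=0 is unacknowledged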
with self.assertRaises(ConfigurationError): - gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS((self.rs_or_single_client(w=0)).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -524,7 +554,7 @@ def test_gridfs_replica_set(self): self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) oid = fs.put(b"foo") - content = fs.get(oid).read() + content = (fs.get(oid)).read() self.assertEqual(b"foo", content) def test_gridfs_secondary(self): @@ -538,7 +568,8 @@ def test_gridfs_secondary(self): fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, fs.put, b"foo") + with self.assertRaises(NotPrimaryError): + fs.put(b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to @@ -552,8 +583,10 @@ def test_gridfs_secondary_lazy(self): fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest") # Connects, doesn't create index. - self.assertRaises(NoFile, fs.get_last_version) - self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8") + with self.assertRaises(NoFile): + fs.get_last_version() + with self.assertRaises(NotPrimaryError): + fs.put("data", encoding="utf-8") if __name__ == "__main__": diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 0af4dce811..e941369f99 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -16,18 +16,21 @@ """Tests for the gridfs package.""" from __future__ import annotations +import asyncio import datetime import itertools import sys import threading import time from io import BytesIO +from test.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary @@ -44,10 +47,12 @@ from pymongo.read_preferences import ReadPreference from pymongo.synchronous.mongo_client import MongoClient +_IS_SYNC = True -class JustWrite(threading.Thread): + +class JustWrite(ConcurrentRunner): def __init__(self, gfs, num): - threading.Thread.__init__(self) + super().__init__() self.gfs = gfs self.num = num self.daemon = True @@ -59,9 +64,9 @@ def run(self): file.close() -class JustRead(threading.Thread): +class JustRead(ConcurrentRunner): def __init__(self, gfs, num, results): - threading.Thread.__init__(self) + super().__init__() self.gfs = gfs self.num = num self.results = results @@ -89,12 +94,13 @@ def setUp(self): def test_basic(self): oid = self.fs.upload_from_stream("test_filename", b"hello world") - self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) - self.assertRaises(NoFile, self.fs.open_download_stream, oid) + with self.assertRaises(NoFile): + self.fs.open_download_stream(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -109,9 +115,20 @@ def test_multi_chunk_delete(self): self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) + def test_delete_by_name(self): + self.assertEqual(0, self.db.fs.files.count_documents({})) + 
self.assertEqual(0, self.db.fs.chunks.count_documents({})) + gfs = gridfs.GridFSBucket(self.db) + gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) + self.assertEqual(1, self.db.fs.files.count_documents({})) + self.assertEqual(5, self.db.fs.chunks.count_documents({})) + gfs.delete_by_name("test_filename") + self.assertEqual(0, self.db.fs.files.count_documents({})) + self.assertEqual(0, self.db.fs.chunks.count_documents({})) + def test_empty_file(self): oid = self.fs.upload_from_stream("test_filename", b"") - self.assertEqual(b"", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"", (self.fs.open_download_stream(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -128,10 +145,12 @@ def test_corrupt_chunk(self): self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.open_download_stream(files_id) - self.assertRaises(CorruptGridFile, out.read) + with self.assertRaises(CorruptGridFile): + out.read() out = self.fs.open_download_stream(files_id) - self.assertRaises(CorruptGridFile, out.readline) + with self.assertRaises(CorruptGridFile): + out.readline() finally: self.fs.delete(files_id) @@ -146,13 +165,13 @@ def test_upload_ensures_index(self): self.assertTrue( any( info.get("key") == [("files_id", 1), ("n", 1)] - for info in chunks.index_information().values() + for info in (chunks.index_information()).values() ) ) self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) @@ -174,25 +193,27 @@ def test_ensure_index_shell_compat(self): self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) files.drop() def test_alt_collection(self): oid = self.alt.upload_from_stream("test_filename", b"hello world") - self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.alt.open_download_stream(oid)).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) - self.assertRaises(NoFile, self.alt.open_download_stream, oid) + with self.assertRaises(NoFile): + self.alt.open_download_stream(oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) - self.assertRaises(NoFile, self.alt.open_download_stream, "foo") + with self.assertRaises(NoFile): + self.alt.open_download_stream("foo") self.alt.upload_from_stream("foo", b"hello world") - self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read()) + self.assertEqual(b"hello world", (self.alt.open_download_stream_by_name("foo")).read()) self.alt.upload_from_stream("mike", b"") self.alt.upload_from_stream("test", b"foo") @@ -200,7 +221,7 @@ def test_alt_collection(self): self.assertEqual( {"mike", "test", "hello world", "foo"}, - {k["filename"] for k in list(self.db.alt.files.find())}, + {k["filename"] for k in self.db.alt.files.find().to_list()}, ) def test_threaded_reads(self): @@ -240,13 +261,14 @@ def test_get_last_version(self): two = two._id three = self.fs.upload_from_stream("test", b"baz") - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"baz", 
(self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(three) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(one) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test") + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test") def test_get_version(self): self.fs.upload_from_stream("test", b"foo") @@ -256,28 +278,30 @@ def test_get_version(self): self.fs.upload_from_stream("test", b"baz") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=0).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=1).read()) - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=2).read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=0)).read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=1)).read()) + self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=2)).read()) - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=-1).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=-2).read()) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=-3).read()) + self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=-1)).read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=-2)).read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=-3)).read()) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test", revision=3) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test", revision=-4) def test_upload_from_stream(self): oid = self.fs.upload_from_stream("test_file", BytesIO(b"hello world"), chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read()) def test_upload_from_stream_with_id(self): oid = ObjectId() self.fs.upload_from_stream_with_id( oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1 ) - self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"custom id", (self.fs.open_download_stream(oid)).read()) @patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3) @client_context.require_failCommand_fail_point @@ -316,14 +340,14 @@ def test_open_upload_stream(self): gin = self.fs.open_upload_stream("from_stream") gin.write(b"from stream") gin.close() - self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read()) + self.assertEqual(b"from stream", (self.fs.open_download_stream(gin._id)).read()) def test_open_upload_stream_with_id(self): oid = ObjectId() gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") gin.write(b"from stream with custom id") gin.close() - self.assertEqual(b"from 
stream with custom id", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"from stream with custom id", (self.fs.open_download_stream(oid)).read()) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -345,12 +369,12 @@ def test_gridfs_lazy_connect(self): client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) - self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) + with self.assertRaises(ServerSelectionTimeoutError): + gfs.delete(0) gfs = gridfs.GridFSBucket(cdb) - self.assertRaises( - ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b"" - ) # Still no connection. + with self.assertRaises(ServerSelectionTimeoutError): + gfs.upload_from_stream("test", b"") # Still no connection. def test_gridfs_find(self): self.fs.upload_from_stream("two", b"test2") @@ -366,14 +390,15 @@ def test_gridfs_find(self): cursor = self.fs.find( {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2 ) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) cursor.rewind() - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test2+", gout.read()) - self.assertRaises(StopIteration, cursor.__next__) + with self.assertRaises(StopIteration): + cursor.next() cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) @@ -383,20 +408,30 @@ def test_grid_in_non_int_chunksize(self): self.fs.upload_from_stream("f", data) self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.open_download_stream_by_name("f").read()) + self.assertEqual(data, (self.fs.open_download_stream_by_name("f")).read()) def test_unacknowledged(self): # w=0 is prohibited. 
with self.assertRaises(ConfigurationError): - gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test) + gridfs.GridFSBucket((self.rs_or_single_client(w=0)).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b"testing") - self.assertEqual(b"testing", self.fs.open_download_stream_by_name("first_name").read()) + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("first_name")).read()) self.fs.rename(_id, "second_name") - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name") - self.assertEqual(b"testing", self.fs.open_download_stream_by_name("second_name").read()) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("first_name") + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("second_name")).read()) + + def test_rename_by_name(self): + _id = self.fs.upload_from_stream("first_name", b"testing") + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("first_name")).read()) + + self.fs.rename_by_name("first_name", "second_name") + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("first_name") + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("second_name")).read()) @patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", 5) def test_abort(self): @@ -407,7 +442,8 @@ def test_abort(self): self.assertEqual(3, self.db.fs.chunks.count_documents({"files_id": gin._id})) gin.abort() self.assertTrue(gin.closed) - self.assertRaises(ValueError, gin.write, b"test4") + with self.assertRaises(ValueError): + gin.write(b"test4") self.assertEqual(0, self.db.fs.chunks.count_documents({"files_id": gin._id})) def test_download_to_stream(self): @@ -490,7 +526,7 @@ def test_gridfs_replica_set(self): gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") oid = gfs.upload_from_stream("test_filename", b"foo") - content = gfs.open_download_stream(oid).read() + content = (gfs.open_download_stream(oid)).read() self.assertEqual(b"foo", content) def test_gridfs_secondary(self): @@ -504,7 +540,8 @@ def test_gridfs_secondary(self): gfs = gridfs.GridFSBucket(secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"foo") + with self.assertRaises(NotPrimaryError): + gfs.upload_from_stream("test_filename", b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to @@ -518,8 +555,10 @@ def test_gridfs_secondary_lazy(self): gfs = gridfs.GridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest") # Connects, doesn't create index. - self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename") - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data") + with self.assertRaises(NoFile): + gfs.open_download_stream_by_name("test_filename") + with self.assertRaises(NotPrimaryError): + gfs.upload_from_stream("test_filename", b"data") if __name__ == "__main__": diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index 6840b6ae0c..e84e19725e 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -17,14 +17,20 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. 
-TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 5e203a33b3..7864caf6e1 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -16,16 +16,19 @@ from __future__ import annotations import sys +from test.utils import MockPool sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, wait_until +from test.utils_shared import HeartbeatEventListener, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat from pymongo.synchronous.monitor import Monitor +_IS_SYNC = True + class TestHeartbeatMonitoring(IntegrationTest): def create_mock_monitor(self, responses, uri, expected_results): @@ -40,8 +43,12 @@ def _check_with_socket(self, *args, **kwargs): raise responses[1] return Hello(responses[1]), 99 - m = self.single_client( - h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool + _ = self.single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=MockPool, + connect=True, ) expected_len = len(expected_results) @@ -50,20 +57,16 @@ def _check_with_socket(self, *args, **kwargs): # of this test. wait_until(lambda: len(listener.events) >= expected_len, "publish all events") - try: - # zip gives us len(expected_results) pairs. - for expected, actual in zip(expected_results, listener.events): - self.assertEqual(expected, actual.__class__.__name__) - self.assertEqual(actual.connection_id, responses[0]) - if expected != "ServerHeartbeatStartedEvent": - if isinstance(actual.reply, Hello): - self.assertEqual(actual.duration, 99) - self.assertEqual(actual.reply._doc, responses[1]) - else: - self.assertEqual(actual.reply, responses[1]) - - finally: - m.close() + # zip gives us len(expected_results) pairs. 
+ for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) def test_standalone(self): responses = ( diff --git a/test/test_index_management.py b/test/test_index_management.py index 6ca726e2e0..dea8c0e2be 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -15,7 +15,9 @@ """Run the auth spec tests.""" from __future__ import annotations +import asyncio import os +import pathlib import sys import time import uuid @@ -27,24 +29,28 @@ from test import IntegrationTest, PyMongoTestCase, unittest from test.unified_format import generate_test_classes -from test.utils import AllowListEventListener, EventListener, OvertCommandListener +from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern -pytestmark = pytest.mark.index_management +_IS_SYNC = True -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management") +pytestmark = pytest.mark.search_index + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") _NAME = "test-search-index" class TestCreateSearchIndex(IntegrationTest): def test_inputs(self): - if not os.environ.get("TEST_INDEX_MANAGEMENT"): - raise unittest.SkipTest("Skipping index management tests") listener = AllowListEventListener("createSearchIndexes") client = self.simple_client(event_listeners=[listener]) coll = client.test.test @@ -82,23 +88,23 @@ class SearchIndexIntegrationBase(PyMongoTestCase): @classmethod def setUpClass(cls) -> None: - super().setUpClass() - if not os.environ.get("TEST_INDEX_MANAGEMENT"): - raise unittest.SkipTest("Skipping index management tests") - url = os.environ.get("MONGODB_URI") - username = os.environ["DB_USER"] - password = os.environ["DB_PASSWORD"] - cls.listener = listener = OvertCommandListener() - cls.client = cls.unmanaged_simple_client( - url, username=username, password=password, event_listeners=[listener] + cls.url = os.environ.get("MONGODB_URI") + cls.username = os.environ["DB_USER"] + cls.password = os.environ["DB_PASSWORD"] + cls.listener = OvertCommandListener() + + def setUp(self) -> None: + self.client = self.simple_client( + self.url, + username=self.username, + password=self.password, + event_listeners=[self.listener], ) - cls.client.drop_database(_NAME) - cls.db = cls.client[cls.db_name] + self.client.drop_database(_NAME) + self.db = self.client[self.db_name] - @classmethod - def tearDownClass(cls): - cls.client.drop_database(_NAME) - cls.client.close() + def tearDown(self): + self.client.drop_database(_NAME) def wait_for_ready(self, coll, name=_NAME, predicate=None): """Wait for a search index to be ready.""" @@ -107,10 +113,9 @@ def wait_for_ready(self, coll, name=_NAME, predicate=None): predicate = lambda index: index.get("queryable") is True while True: - indices = list(coll.list_search_indexes(name)) + indices = 
(coll.list_search_indexes(name)).to_list() if len(indices) and predicate(indices[0]): return indices[0] - break time.sleep(5) @@ -133,7 +138,7 @@ def test_comment_field(self): # Get the index definition. self.listener.reset() - coll0.list_search_indexes(name=implicit_search_resp, comment="foo").next() + (coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next() event = self.listener.events[0] self.assertEqual(event.command["comment"], "foo") @@ -183,7 +188,7 @@ def test_case_2(self): ) # .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``. - indices = list(coll0.list_search_indexes()) + indices = (coll0.list_search_indexes()).to_list() names = [i["name"] for i in indices] self.assertIn(name1, names) self.assertIn(name2, names) @@ -223,7 +228,7 @@ def test_case_3(self): # Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array. t0 = time.time() while True: - indices = list(coll0.list_search_indexes()) + indices = (coll0.list_search_indexes()).to_list() if indices: break if (time.time() - t0) / 60 > 5: @@ -259,7 +264,7 @@ def test_case_4(self): self.wait_for_ready(coll0, predicate=predicate) # Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``. - index = list(coll0.list_search_indexes(_NAME))[0] + index = ((coll0.list_search_indexes(_NAME)).to_list())[0] self.assertIn("latestDefinition", index) self.assertEqual(index["latestDefinition"], model2["definition"]) @@ -324,7 +329,7 @@ def test_case_7(self): ) # Get the index definition. - resp = coll0.list_search_indexes(name=implicit_search_resp).next() + resp = (coll0.list_search_indexes(name=implicit_search_resp)).next() # Assert that the index model contains the correct index type: ``"search"``. self.assertEqual(resp["type"], "search") @@ -335,7 +340,7 @@ def test_case_7(self): ) # Get the index definition. - resp = coll0.list_search_indexes(name=explicit_search_resp).next() + resp = (coll0.list_search_indexes(name=explicit_search_resp)).next() # Assert that the index model contains the correct index type: ``"search"``. self.assertEqual(resp["type"], "search") @@ -350,7 +355,7 @@ def test_case_7(self): ) # Get the index definition. - resp = coll0.list_search_indexes(name=explicit_vector_resp).next() + resp = (coll0.list_search_indexes(name=explicit_vector_resp)).next() # Assert that the index model contains the correct index type: ``"vectorSearch"``. 
self.assertEqual(resp["type"], "vectorSearch") diff --git a/test/test_json_util.py b/test/test_json_util.py index 821ca76da0..8aed4a82bc 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -21,13 +21,13 @@ import sys import uuid from collections import OrderedDict -from typing import Any, List, MutableMapping, Tuple, Type +from typing import Any, Tuple, Type from bson.codec_options import CodecOptions, DatetimeConversion sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import unittest from bson import EPOCH_AWARE, EPOCH_NAIVE, SON, DatetimeMS, json_util from bson.binary import ( @@ -636,24 +636,5 @@ class MyBinary(Binary): self.assertEqual(json_util.dumps(MyBinary(b"bin", USER_DEFINED_SUBTYPE)), expected_json) -class TestJsonUtilRoundtrip(IntegrationTest): - def test_cursor(self): - db = self.db - - db.drop_collection("test") - docs: List[MutableMapping[str, Any]] = [ - {"foo": [1, 2]}, - {"bar": {"hello": "world"}}, - {"code": Code("function x() { return 1; }")}, - {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, - {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, - ] - - db.test.insert_many(docs) - reloaded_docs = json_util.loads(json_util.dumps(db.test.find())) - for doc in docs: - self.assertTrue(doc in reloaded_docs) - - if __name__ == "__main__": unittest.main() diff --git a/test/test_json_util_integration.py b/test/test_json_util_integration.py new file mode 100644 index 0000000000..acab4f3182 --- /dev/null +++ b/test/test_json_util_integration.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from test import IntegrationTest +from typing import Any, List, MutableMapping + +from bson import Binary, Code, DBRef, ObjectId, json_util +from bson.binary import USER_DEFINED_SUBTYPE + +_IS_SYNC = True + + +class TestJsonUtilRoundtrip(IntegrationTest): + def test_cursor(self): + db = self.db + + db.drop_collection("test") + docs: List[MutableMapping[str, Any]] = [ + {"foo": [1, 2]}, + {"bar": {"hello": "world"}}, + {"code": Code("function x() { return 1; }")}, + {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, + {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, + ] + + db.test.insert_many(docs) + reloaded_docs = json_util.loads(json_util.dumps((db.test.find()).to_list())) + for doc in docs: + self.assertTrue(doc in reloaded_docs) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 23bea4d984..d7f1d596cc 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -15,10 +15,15 @@ """Test the Load Balancer unified spec tests.""" from __future__ import annotations +import asyncio import gc import os +import pathlib import sys import threading +from asyncio import Event +from test.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.utils import get_pool import pytest @@ -26,15 +31,25 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, wait_until +from test.utils_shared import ( + create_event, + wait_until, +) + +from pymongo.synchronous.helpers import next + +_IS_SYNC = True pytestmark = pytest.mark.load_balancer # Location of JSON test specifications. 
-TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "load_balancer") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) class TestLB(IntegrationTest): @@ -49,13 +64,12 @@ def test_connections_are_only_returned_once(self): n_conns = len(pool.conns) self.db.test.find_one({}) self.assertEqual(len(pool.conns), n_conns) - list(self.db.test.aggregate([{"$limit": 1}])) + (self.db.test.aggregate([{"$limit": 1}])).to_list() self.assertEqual(len(pool.conns), n_conns) @client_context.require_load_balancer def test_unpin_committed_transaction(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test with client.start_session() as session: @@ -86,7 +100,6 @@ def create_resource(coll): def _test_no_gc_deadlock(self, create_resource): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test coll.insert_many([{} for _ in range(10)]) @@ -104,19 +117,19 @@ def _test_no_gc_deadlock(self, create_resource): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") # Garbage collect the resource while the pool is locked to ensure we # don't deadlock. del resource # On PyPy it can take a few rounds to collect the cursor. for _ in range(3): gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. @@ -125,7 +138,6 @@ def _test_no_gc_deadlock(self, create_resource): @client_context.require_transactions def test_session_gc(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() session.start_transaction() @@ -137,41 +149,51 @@ def test_session_gc(self): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") # Garbage collect the session while the pool is locked to ensure we # don't deadlock. del session # On PyPy it can take a few rounds to collect the session. for _ in range(3): gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. 
client[self.db.name].test.delete_many({}) -class PoolLocker(ExceptionCatchingThread): +class PoolLocker(ExceptionCatchingTask): def __init__(self, pool): super().__init__(target=self.lock_pool) self.pool = pool self.daemon = True - self.locked = threading.Event() - self.unlock = threading.Event() + self.locked = create_event() + self.unlock = create_event() def lock_pool(self): with self.pool.lock: self.locked.set() # Wait for the unlock flag. - unlock_pool = self.unlock.wait(10) + unlock_pool = self.wait(self.unlock, 10) if not unlock_pool: raise Exception("timed out waiting for unlock signal: deadlock?") + def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) # type: ignore[call-arg] + else: + try: + asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True + if __name__ == "__main__": unittest.main() diff --git a/test/test_logger.py b/test/test_logger.py index b3c8e6d176..a7d97927fa 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -14,7 +14,7 @@ from __future__ import annotations import os -from test import IntegrationTest, unittest +from test import IntegrationTest, client_context, unittest from unittest.mock import patch from bson import json_util @@ -96,6 +96,49 @@ def test_logging_without_listeners(self): c.db.test.insert_one({"x": "1"}) self.assertGreater(len(cm.records), 0) + @client_context.require_failCommand_fail_point + def test_logging_retry_read_attempts(self): + self.db.test.insert_one({"x": "1"}) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + self.db.test.find_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying read attempt" in r.getMessage() + ] + self.assertEqual(len(retry_messages), 1) + + @client_context.require_failCommand_fail_point + @client_context.require_retryable_writes + def test_logging_retry_write_attempts(self): + with self.fail_point( + { + "mode": {"times": 1}, + "data": { + "errorCode": 10107, + "errorLabels": ["RetryableWriteError"], + "failCommands": ["insert"], + }, + } + ): + with self.assertLogs("pymongo.command", level="DEBUG") as cm: + self.db.test.insert_one({"x": "1"}) + + retry_messages = [ + r.getMessage() for r in cm.records if "Retrying write attempt" in r.getMessage() + ] + self.assertEqual(len(retry_messages), 1) + if __name__ == "__main__": unittest.main() diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 32d09ada9a..56e047fd4b 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -15,10 +15,12 @@ """Test maxStalenessSeconds support.""" from __future__ import annotations +import asyncio import os import sys import time import warnings +from pathlib import Path from pymongo import MongoClient from pymongo.operations import _Op @@ -31,11 +33,16 @@ from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector +_IS_SYNC = True + # Location of JSON test specifications. 
-_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness") -class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore pass diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 7bc8225465..8c31854343 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -15,8 +15,10 @@ """Test MongoClient's mongos load balancing using a mock.""" from __future__ import annotations +import asyncio import sys import threading +from test.helpers import ConcurrentRunner from pymongo.operations import _Op @@ -24,20 +26,16 @@ from test import MockClientTest, client_context, connected, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector from pymongo.topology_description import TOPOLOGY_TYPE +_IS_SYNC = True -@client_context.require_connection -@client_context.require_no_load_balancer -def setUpModule(): - pass - -class SimpleOp(threading.Thread): +class SimpleOp(ConcurrentRunner): def __init__(self, client): super().__init__() self.client = client @@ -48,15 +46,15 @@ def run(self): self.passed = True # No exception raised. -def do_simple_op(client, nthreads): - threads = [SimpleOp(client) for _ in range(nthreads)] - for t in threads: +def do_simple_op(client, ntasks): + tasks = [SimpleOp(client) for _ in range(ntasks)] + for t in tasks: t.start() - for t in threads: + for t in tasks: t.join() - for t in threads: + for t in tasks: assert t.passed @@ -68,6 +66,11 @@ def writable_addresses(topology): class TestMongosLoadBalancing(MockClientTest): + @client_context.require_connection + @client_context.require_no_load_balancer + def setUp(self): + super().setUp() + def mock_client(self, **kwargs): mock_client = MockClient( standalones=[], @@ -98,7 +101,7 @@ def test_lazy_connect(self): wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") def test_failover(self): - nthreads = 10 + ntasks = 10 client = connected(self.mock_client(localThresholdMS=0.001)) wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") @@ -118,14 +121,14 @@ def f(): passed.append(True) - threads = [threading.Thread(target=f) for _ in range(nthreads)] - for t in threads: + tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)] + for t in tasks: t.start() - for t in threads: + for t in tasks: t.join() - self.assertEqual(nthreads, len(passed)) + self.assertEqual(ntasks, len(passed)) # Down host removed from list. self.assertEqual(2, len(client.nodes)) @@ -183,8 +186,11 @@ def test_load_balancing(self): client.mock_rtts["a:1"] = 0.045 # Discover only b is within latency window. 
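The `if _IS_SYNC:` path stanza just above recurs across the patch. The assumed layout: the async originals live one directory deeper (e.g. `test/asynchronous/`), the generated sync files live in `test/`, and both must resolve the shared JSON spec directory, hence the extra `.parent` on the async side:

```python
import os
from pathlib import Path

_IS_SYNC = True  # assumed to be False in the async source file

# test/test_max_staleness.py              -> test/max_staleness
# test/asynchronous/test_max_staleness.py -> one extra .parent, same dir
if _IS_SYNC:
    TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness")
else:
    TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness")
```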
+ def predicate(): + return {("b", 2)} == writable_addresses(topology) + wait_until( - lambda: {("b", 2)} == writable_addresses(topology), + predicate, 'discover server "a" is too far', ) diff --git a/test/test_monitor.py b/test/test_monitor.py index a704f3d8cb..25620a99e8 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -15,6 +15,7 @@ """Test the monitor module.""" from __future__ import annotations +import asyncio import gc import subprocess import sys @@ -23,14 +24,16 @@ sys.path[0:0] = [""] -from test import IntegrationTest, connected, unittest +from test import IntegrationTest, client_context, connected, unittest from test.utils import ( - ServerAndTopologyEventListener, wait_until, ) +from test.utils_shared import ServerAndTopologyEventListener from pymongo.periodic_executor import _EXECUTORS +_IS_SYNC = True + def unregistered(ref): gc.collect() @@ -55,8 +58,8 @@ def create_client(self): return client def test_cleanup_executors_on_client_del(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") client = self.create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) @@ -70,6 +73,19 @@ def test_cleanup_executors_on_client_del(self): for ref, name in executor_refs: wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) + def resource_warning_caught(): + gc.collect() + for warning in w: + if ( + issubclass(warning.category, ResourceWarning) + and "Call MongoClient.close() to safely shut down your client and free up resources." + in str(warning.message) + ): + return True + return False + + wait_until(resource_warning_caught, "catch resource warning") + def test_cleanup_executors_on_client_close(self): client = self.create_client() executors = get_executors(client) @@ -80,10 +96,15 @@ def test_cleanup_executors_on_client_close(self): for executor in executors: wait_until(lambda: executor._stopped, f"closed executor: {executor._name}", timeout=5) + @client_context.require_sync def test_no_thread_start_runtime_err_on_shutdown(self): """Test we silence noisy runtime errors fired when the MongoClient spawns a new thread on process shutdown.""" - command = [sys.executable, "-c", "from pymongo import MongoClient; c = MongoClient()"] + command = [ + sys.executable, + "-c", + "from pymongo import MongoClient; c = MongoClient()", + ] completed_process: subprocess.CompletedProcess = subprocess.run( command, capture_output=True ) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 670558c0a0..ae3e50db77 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, wait_until, diff --git a/test/test_objectid.py b/test/test_objectid.py index 26670832f6..d7db7229ea 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] from test import SkipTest, unittest -from test.utils import oid_generated_on_process +from test.utils_shared import oid_generated_on_process from bson.errors import InvalidId from bson.objectid import _MAX_COUNTER_VALUE, ObjectId diff --git a/test/test_on_demand_csfle.py b/test/test_on_demand_csfle.py index 023feca8c2..648e46815a 100644 --- a/test/test_on_demand_csfle.py +++ b/test/test_on_demand_csfle.py @@ -26,18 +26,20 @@ from test import IntegrationTest, client_context from bson.codec_options import 
CodecOptions -from pymongo.synchronous.encryption import _HAVE_PYMONGOCRYPT, ClientEncryption, EncryptionError +from pymongo.synchronous.encryption import ( + _HAVE_PYMONGOCRYPT, + ClientEncryption, + EncryptionError, +) -pytestmark = pytest.mark.csfle +_IS_SYNC = True + +pytestmark = pytest.mark.kms class TestonDemandGCPCredentials(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() self.master_key = { @@ -74,12 +76,8 @@ def test_02_success(self): class TestonDemandAzureCredentials(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() self.master_key = { diff --git a/test/test_pooling.py b/test/test_pooling.py index 3b867965bd..05513afe12 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -15,42 +15,42 @@ """Test built in connection-pooling with threads.""" from __future__ import annotations +import asyncio import gc import random import socket import sys -import threading import time +from test.utils import get_pool, joinall from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON from pymongo import MongoClient, message, timeout from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError from pymongo.hello import HelloCompat +from pymongo.lock import _create_lock sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import delay, get_pool, joinall +from test.helpers import ConcurrentRunner +from test.utils_shared import delay from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions - -@client_context.require_connection -def setUpModule(): - pass +_IS_SYNC = True N = 10 DB = "pymongo-pooling-tests" -def gc_collect_until_done(threads, timeout=60): +def gc_collect_until_done(tasks, timeout=60): start = time.time() - running = list(threads) + running = list(tasks) while running: - assert (time.time() - start) < timeout, "Threads timed out" + assert (time.time() - start) < timeout, "Tasks timed out" for t in running: t.join(0.1) if not t.is_alive(): @@ -58,12 +58,12 @@ def gc_collect_until_done(threads, timeout=60): gc.collect() -class MongoThread(threading.Thread): - """A thread that uses a MongoClient.""" +class MongoTask(ConcurrentRunner): + """A thread/Task that uses a MongoClient.""" def __init__(self, client): super().__init__() - self.daemon = True # Don't hang whole test if thread hangs. + self.daemon = True # Don't hang whole test if task hangs. 
self.client = client self.db = self.client[DB] self.passed = False @@ -76,21 +76,21 @@ def run_mongo_thread(self): raise NotImplementedError -class InsertOneAndFind(MongoThread): +class InsertOneAndFind(MongoTask): def run_mongo_thread(self): for _ in range(N): rand = random.randint(0, N) - _id = self.db.sf.insert_one({"x": rand}).inserted_id - assert rand == self.db.sf.find_one(_id)["x"] + _id = (self.db.sf.insert_one({"x": rand})).inserted_id + assert rand == (self.db.sf.find_one(_id))["x"] -class Unique(MongoThread): +class Unique(MongoTask): def run_mongo_thread(self): for _ in range(N): self.db.unique.insert_one({}) # no error -class NonUnique(MongoThread): +class NonUnique(MongoTask): def run_mongo_thread(self): for _ in range(N): try: @@ -101,7 +101,7 @@ def run_mongo_thread(self): raise AssertionError("Should have raised DuplicateKeyError") -class SocketGetter(MongoThread): +class SocketGetter(MongoTask): """Utility for TestPooling. Checks out a socket and holds it forever. Used in @@ -124,31 +124,35 @@ def run_mongo_thread(self): self.state = "connection" - def __del__(self): + def release_conn(self): if self.sock: - self.sock.close_conn(None) + self.sock.unpin() + self.sock = None + return True + return False def run_cases(client, cases): - threads = [] + tasks = [] n_runs = 5 for case in cases: for _i in range(n_runs): t = case(client) t.start() - threads.append(t) + tasks.append(t) - for t in threads: + for t in tasks: t.join() - for t in threads: + for t in tasks: assert t.passed, "%s.run() threw an exception" % repr(t) class _TestPoolingBase(IntegrationTest): """Base class for all connection-pool tests.""" + @client_context.require_connection def setUp(self): super().setUp() self.c = self.rs_or_single_client() @@ -158,11 +162,9 @@ def setUp(self): db.unique.insert_one({"_id": "jesse"}) db.test.insert_many([{} for _ in range(10)]) - def tearDown(self): - self.c.close() - super().tearDown() - - def create_pool(self, pair=(client_context.host, client_context.port), *args, **kwargs): + def create_pool(self, pair=None, *args, **kwargs): + if pair is None: + pair = (client_context.host, client_context.port) # Start the pool with the correct ssl options. 
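`create_pool` moving its `(client_context.host, client_context.port)` default out of the signature fixes an early-binding pitfall: default expressions run once, when `def` executes, and `client_context` is only populated after test setup. The general shape of the bug and the sentinel fix:

```python
import time

def bad(stamp=time.time()):   # evaluated once, at definition time
    return stamp

def good(stamp=None):         # sentinel; resolved on every call
    if stamp is None:
        stamp = time.time()
    return stamp

a = bad(); time.sleep(0.05); b = bad()
assert a == b                 # frozen at import

c = good(); time.sleep(0.05); d = good()
assert c != d                 # fresh each call
```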
pool_options = client_context.client._topology_settings.pool_options kwargs["ssl_context"] = pool_options._ssl_context @@ -354,6 +356,10 @@ def test_no_wait_queue_timeout(self): self.assertEqual(t.state, "connection") self.assertEqual(t.sock, s1) + # Cleanup + t.release_conn() + t.join() + pool.close() def test_checkout_more_than_max_pool_size(self): pool = self.create_pool(max_pool_size=2) @@ -365,21 +371,30 @@ def test_checkout_more_than_max_pool_size(self): sock.pin_cursor() socks.append(sock) - threads = [] - for _ in range(30): + tasks = [] + for _ in range(10): t = SocketGetter(self.c, pool) t.start() - threads.append(t) + tasks.append(t) time.sleep(1) - for t in threads: + for t in tasks: self.assertEqual(t.state, "get_socket") - + # Cleanup for socket_info in socks: - socket_info.close_conn(None) + socket_info.unpin() + while tasks: + to_remove = [] + for t in tasks: + if t.release_conn(): + to_remove.append(t) + t.join() + for t in to_remove: + tasks.remove(t) + time.sleep(0.05) + pool.close() def test_maxConnecting(self): client = self.rs_or_single_client() - self.addCleanup(client.close) self.client.test.test.insert_one({}) self.addCleanup(self.client.test.test.delete_many, {}) pool = get_pool(client) @@ -389,11 +404,11 @@ def test_maxConnecting(self): def find_one(): docs.append(client.test.test.find_one({})) - threads = [threading.Thread(target=find_one) for _ in range(50)] - for thread in threads: - thread.start() - for thread in threads: - thread.join(10) + tasks = [ConcurrentRunner(target=find_one) for _ in range(50)] + for task in tasks: + task.start() + for task in tasks: + task.join(10) self.assertEqual(len(docs), 50) self.assertLessEqual(len(pool.conns), 50) @@ -416,7 +431,6 @@ def find_one(): @client_context.require_failCommand_appName def test_csot_timeout_message(self): client = self.rs_or_single_client(appName="connectionTimeoutApp") - self.addCleanup(client.close) # Mock an operation failing due to pymongo.timeout(). mock_connection_timeout = { "configureFailPoint": "failCommand", @@ -436,12 +450,11 @@ def test_csot_timeout_message(self): with timeout(0.5): client.db.t.find_one({"$where": delay(2)}) - self.assertTrue("(configured timeouts: timeoutMS: 500.0ms" in str(error.exception)) + self.assertIn("(configured timeouts: timeoutMS: 500.0ms", str(error.exception)) @client_context.require_failCommand_appName def test_socket_timeout_message(self): client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") - self.addCleanup(client.close) # Mock an operation failing due to socketTimeoutMS. 
mock_connection_timeout = { "configureFailPoint": "failCommand", @@ -460,9 +473,9 @@ def test_socket_timeout_message(self): with self.assertRaises(Exception) as error: client.db.t.find_one({"$where": delay(2)}) - self.assertTrue( - "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)" - in str(error.exception) + self.assertIn( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)", + str(error.exception), ) @client_context.require_failCommand_appName @@ -485,7 +498,6 @@ def test_connection_timeout_message(self): appName="connectionTimeoutApp", heartbeatFrequencyMS=1000000, ) - self.addCleanup(client.close) client.admin.command("ping") pool = get_pool(client) pool.reset_without_pause() @@ -493,9 +505,9 @@ def test_connection_timeout_message(self): with self.assertRaises(Exception) as error: client.admin.command("ping") - self.assertTrue( - "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 500.0ms)" - in str(error.exception) + self.assertIn( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 500.0ms)", + str(error.exception), ) @@ -503,20 +515,19 @@ class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 c = self.rs_or_single_client(maxPoolSize=max_pool_size) - self.addCleanup(c.close) collection = c[DB].test # Need one document. collection.drop() collection.insert_one({}) - # nthreads had better be much larger than max_pool_size to ensure that + # ntasks had better be much larger than max_pool_size to ensure that # max_pool_size connections are actually required at some point in this # test's execution. cx_pool = get_pool(c) - nthreads = 10 - threads = [] - lock = threading.Lock() + ntasks = 10 + tasks = [] + lock = _create_lock() self.n_passed = 0 def f(): @@ -527,19 +538,18 @@ def f(): with lock: self.n_passed += 1 - for _i in range(nthreads): - t = threading.Thread(target=f) - threads.append(t) + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) t.start() - joinall(threads) - self.assertEqual(nthreads, self.n_passed) + joinall(tasks) + self.assertEqual(ntasks, self.n_passed) self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): c = self.rs_or_single_client(maxPoolSize=None) - self.addCleanup(c.close) collection = c[DB].test # Need one document. 
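The `assertTrue(x in y)` to `assertIn(x, y)` rewrites in these timeout-message tests are behavior-neutral; the gain is diagnostic output, since `assertIn` prints both the needle and the haystack on failure instead of `False is not true`:

```python
import unittest

class Demo(unittest.TestCase):
    def test_failure_messages(self):
        msg = "(configured timeouts: socketTimeoutMS: 500.0ms)"
        self.assertTrue("socketTimeoutMS" in msg)  # fails with "False is not true"
        self.assertIn("socketTimeoutMS", msg)      # fails showing needle and haystack

if __name__ == "__main__":
    unittest.main()
```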
@@ -547,9 +557,9 @@ def test_max_pool_size_none(self): collection.insert_one({}) cx_pool = get_pool(c) - nthreads = 10 - threads = [] - lock = threading.Lock() + ntasks = 10 + tasks = [] + lock = _create_lock() self.n_passed = 0 def f(): @@ -559,19 +569,18 @@ def f(): with lock: self.n_passed += 1 - for _i in range(nthreads): - t = threading.Thread(target=f) - threads.append(t) + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) t.start() - joinall(threads) - self.assertEqual(nthreads, self.n_passed) + joinall(tasks) + self.assertEqual(ntasks, self.n_passed) self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): c = self.rs_or_single_client(maxPoolSize=0) - self.addCleanup(c.close) pool = get_pool(c) self.assertEqual(pool.max_pool_size, float("inf")) diff --git a/test/test_read_concern.py b/test/test_read_concern.py index f7c0901422..62b2491475 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,12 +21,14 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure from pymongo.read_concern import ReadConcern +_IS_SYNC = True + class TestReadConcern(IntegrationTest): listener: OvertCommandListener @@ -71,14 +73,14 @@ def test_invalid_read_concern(self): def test_find_command(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.find({"field": "value"})) + coll.find({"field": "value"}).to_list() self.assertNotIn("readConcern", self.listener.started_events[0].command) self.listener.reset() # Explicitly set readConcern to 'local'. coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.find({"field": "value"})) + coll.find({"field": "value"}).to_list() self.assertEqualCommand( SON( [ @@ -93,19 +95,19 @@ def test_find_command(self): def test_command_cursor(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.aggregate([{"$match": {"field": "value"}}])) + (coll.aggregate([{"$match": {"field": "value"}}])).to_list() self.assertNotIn("readConcern", self.listener.started_events[0].command) self.listener.reset() # Explicitly set readConcern to 'local'. coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.aggregate([{"$match": {"field": "value"}}])) + (coll.aggregate([{"$match": {"field": "value"}}])).to_list() self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"]) def test_aggregate_out(self): coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) - tuple(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])) + (coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])).to_list() # Aggregate with $out supports readConcern MongoDB 4.2 onwards. 
if client_context.version >= (4, 1): diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 32883399e1..afde01723d 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -26,9 +26,16 @@ sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, connected, unittest -from test.utils import ( +from test import ( + IntegrationTest, + SkipTest, + client_context, + connected, + unittest, +) +from test.utils_shared import ( OvertCommandListener, + _ignore_deprecations, one, wait_until, ) @@ -49,16 +56,22 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, readable_server_selector from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.helpers import next from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): client = self.single_client() - wait_until(lambda: client.address, "discover primary") + def predicate(): + return client.address + + wait_until(predicate, "discover primary") selection = Selection.from_topology_description(client._topology.description) self.assertTrue(selection) @@ -88,11 +101,7 @@ def test_deepcopy(self): class TestReadPreferencesBase(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() # Insert some data so we can use cursors in read_from_which_host @@ -123,11 +132,14 @@ def read_from_which_kind(self, client): f"Cursor used address {address}, expected either primary " f"{client.primary} or secondaries {client.secondaries}" ) - return None def assertReadsFrom(self, expected, **kwargs): c = self.rs_client(**kwargs) - wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") + + def predicate(): + return len(c.nodes - c.arbiters) == client_context.w + + wait_until(predicate, "discovered all nodes") used = self.read_from_which_kind(c) self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") @@ -150,7 +162,7 @@ def test_reads_from_secondary(self): # Test find and find_one. self.assertIsNotNone(coll.find_one()) - self.assertEqual(10, len(list(coll.find()))) + self.assertEqual(10, len(coll.find().to_list())) # Test some database helpers. 
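Several `wait_until(lambda: ..., msg)` calls become named `def predicate()` functions in this file. A plausible motive (assumption): the async generator can turn a `def` into an `async def` and insert awaits, which it cannot do inside a lambda. For reference, a minimal sketch of the polling helper's assumed semantics (the real one is imported from `test.utils_shared`):

```python
import time

def wait_until(predicate, success_description, timeout=10):
    """Poll `predicate` until it returns a truthy value, else fail."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        retval = predicate()
        if retval:
            return retval
        time.sleep(0.1)
    raise AssertionError(f"timed out waiting until {success_description}")

def predicate():  # a def, unlike a lambda, has an async twin
    return True

assert wait_until(predicate, "predicate returns True")
```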
self.assertIsNotNone(db.list_collection_names()) @@ -173,20 +185,22 @@ def test_mode_validation(self): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, (self.rs_client(read_preference=mode)).read_preference) - self.assertRaises(TypeError, self.rs_client, read_preference="foo") + with self.assertRaises(TypeError): + self.rs_client(read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) self.assertEqual( - [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + [{"k": "v"}, {}], + (self.rs_client(read_preference=S)).read_preference.tag_sets, ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,22 +214,27 @@ def test_tag_sets_validation(self): def test_threshold_validation(self): self.assertEqual( - 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, + (self.rs_client(localThresholdMS=17, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, + (self.rs_client(localThresholdMS=42, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, + (self.rs_client(localThresholdMS=666, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + 0, + (self.rs_client(localThresholdMS=0, connect=False)).options.local_threshold_ms, ) - self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) + with self.assertRaises(ValueError): + self.rs_client(localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -238,7 +257,8 @@ def test_primary(self): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) + with self.assertRaises(ConfigurationError): + self.rs_client(tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -272,7 +292,7 @@ def test_nearest(self): not_used = data_members.difference(used) latencies = ", ".join( "%s: %sms" % (server.description.address, server.description.round_trip_time) - for server in c._get_topology().select_servers(readable_server_selector, _Op.TEST) + for server in (c._get_topology()).select_servers(readable_server_selector, _Op.TEST) ) self.assertFalse( @@ -289,12 +309,9 @@ def __init__(self, *args, **kwargs): client_options.update(kwargs) super().__init__(*args, **client_options) - @contextlib.contextmanager def _conn_for_reads(self, read_preference, session, operation): context = super()._conn_for_reads(read_preference, session, operation) - with context as (conn, read_preference): - self.record_a_read(conn.address) - yield conn, read_preference + return context @contextlib.contextmanager def _conn_from_server(self, read_preference, server, 
session): @@ -304,7 +321,7 @@ def _conn_from_server(self, read_preference, server, session): yield conn, read_preference def record_a_read(self, address): - server = self._get_topology().select_server_by_address(address, _Op.TEST, 0) + server = (self._get_topology()).select_server_by_address(address, _Op.TEST, 0) self.has_read_from.add(server) @@ -321,25 +338,23 @@ class TestCommandAndReadPreference(IntegrationTest): c: ReadPrefTester client_version: Version - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - cls.c = ReadPrefTester( + def setUp(self): + super().setUp() + self.c = ReadPrefTester( # Ignore round trip times, to test ReadPreference modes only. localThresholdMS=1000 * 1000, ) - cls.client_version = Version.from_client(cls.c) + self.client_version = Version.from_client(self.c) # mapReduce fails if the collection does not exist. - coll = cls.c.pymongo_test.get_collection( + coll = self.c.pymongo_test.get_collection( "test", write_concern=WriteConcern(w=client_context.w) ) coll.insert_one({}) - @classmethod - def tearDownClass(cls): - cls.c.drop_database("pymongo_test") - cls.c.close() + def tearDown(self): + self.c.drop_database("pymongo_test") + self.c.close() def executed_on_which_server(self, client, fn, *args, **kwargs): """Execute fn(*args, **kwargs) and return the Server instance used.""" @@ -366,7 +381,7 @@ def _test_fn(self, server_type, fn): break assert self.c.primary is not None - unused = self.c.secondaries.union({self.c.primary}).difference(used) + unused = (self.c.secondaries).union({self.c.primary}).difference(used) if unused: self.fail("Some members not used for NEAREST: %s" % (unused)) else: @@ -401,11 +416,12 @@ def func(): def test_create_collection(self): # create_collection runs listCollections on the primary to check if # the collection already exists. - self._test_primary_helper( - lambda: self.c.pymongo_test.create_collection( + def func(): + return self.c.pymongo_test.create_collection( "some_collection%s" % random.randint(0, sys.maxsize) ) - ) + + self._test_primary_helper(func) def test_count_documents(self): self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) @@ -507,33 +523,44 @@ def test_read_preference_document_hedge(self): for mode, cls in cases.items(): with self.assertRaises(TypeError): cls(hedge=[]) # type: ignore - - pref = cls(hedge={}) - self.assertEqual(pref.document, {"mode": mode}) - out = _maybe_add_read_preference({}, pref) - if cls == SecondaryPreferred: - # SecondaryPreferred without hedge doesn't add $readPreference. - self.assertEqual(out, {}) - else: + with _ignore_deprecations(): + pref = cls(hedge={}) + self.assertEqual(pref.document, {"mode": mode}) + out = _maybe_add_read_preference({}, pref) + if cls == SecondaryPreferred: + # SecondaryPreferred without hedge doesn't add $readPreference. 
+ self.assertEqual(out, {}) + else: + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge: dict[str, Any] = {"enabled": True} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge: dict[str, Any] = {"enabled": True} - pref = cls(hedge=hedge) - self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) - out = _maybe_add_read_preference({}, pref) - self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + hedge = {"enabled": False} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge = {"enabled": False} - pref = cls(hedge=hedge) - self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) - out = _maybe_add_read_preference({}, pref) - self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + hedge = {"enabled": False, "extra": "option"} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge = {"enabled": False, "extra": "option"} - pref = cls(hedge=hedge) - self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) - out = _maybe_add_read_preference({}, pref) - self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + def test_read_preference_hedge_deprecated(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondary": Secondary, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + for _, cls in cases.items(): + with self.assertRaises(DeprecationWarning): + cls(hedge={"enabled": True}) def test_send_hedge(self): cases = { @@ -545,10 +572,10 @@ def test_send_hedge(self): cases["secondary"] = Secondary listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): - pref = cls(hedge={"enabled": True}) + with _ignore_deprecations(): + pref = cls(hedge={"enabled": True}) coll = client.test.get_collection("test", read_preference=pref) listener.reset() coll.find_one() @@ -645,10 +672,10 @@ def test_mongos(self): # tell what shard member a query ran on. 
for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): qcoll = coll.with_options(read_preference=pref) - results = list(qcoll.find().sort([("_id", 1)])) + results = qcoll.find().sort([("_id", 1)]).to_list() self.assertEqual(first_id, results[0]["_id"]) self.assertEqual(last_id, results[-1]["_id"]) - results = list(qcoll.find().sort([("_id", -1)])) + results = qcoll.find().sort([("_id", -1)]).to_list() self.assertEqual(first_id, results[-1]["_id"]) self.assertEqual(last_id, results[0]["_id"]) @@ -671,14 +698,14 @@ def test_mongos_max_staleness(self): else: self.fail("mongos accepted invalid staleness") - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=120 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=120) ).pymongo_test.test # No error coll.find_one() - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=10 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=10) ).pymongo_test.test try: coll.find_one() diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index db53b67ae4..383dc70902 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -19,12 +19,13 @@ import os import sys import warnings +from pathlib import Path sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo import DESCENDING from pymongo.errors import ( @@ -39,7 +40,13 @@ from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern") +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") class TestReadWriteConcernSpec(IntegrationTest): @@ -47,7 +54,6 @@ def test_omit_default_read_write_concern(self): listener = OvertCommandListener() # Client with default readConcern and writeConcern client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). 
collection.insert_many([{} for _ in range(10)]) @@ -66,9 +72,12 @@ def insert_command_default_write_concern(): "insert", "collection", documents=[{}], write_concern=WriteConcern() ) + def aggregate_op(): + (collection.aggregate([])).to_list() + ops = [ - ("aggregate", lambda: list(collection.aggregate([]))), - ("find", lambda: list(collection.find())), + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), ("insert_one", lambda: collection.insert_one({})), ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), @@ -207,7 +216,6 @@ def test_error_includes_errInfo(self): def test_write_error_details_exposes_errinfo(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) db = client.errinfotest self.addCleanup(client.drop_database, "errinfotest") validator = {"x": {"$type": "string"}} @@ -286,7 +294,7 @@ def run_test(self): def create_tests(): - for dirpath, _, filenames in os.walk(_TEST_PATH): + for dirpath, _, filenames in os.walk(TEST_PATH): dirname = os.path.split(dirpath)[-1] if dirname == "operation": @@ -321,7 +329,7 @@ def create_tests(): # PyMongo does not support MapReduce. globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "operation"), + os.path.join(TEST_PATH, "operation"), module=__name__, expected_failures=["MapReduce .*"], ) diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 4c23d71b69..3371543f27 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -21,7 +21,7 @@ from test import MockClientTest, client_context, client_knobs, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo import ReadPreference from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 9c3f6b170f..7ae4c41e70 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.utils import set_fail_point from pymongo.errors import AutoReconnect @@ -31,10 +32,9 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, - set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/test_retryable_reads_unified.py b/test/test_retryable_reads_unified.py index 3f8740cf4b..b1c6435c9a 100644 --- a/test/test_retryable_reads_unified.py +++ b/test/test_retryable_reads_unified.py @@ -15,6 +15,7 @@ """Test the Retryable Reads unified spec tests.""" from __future__ import annotations +import os import sys from pathlib import Path @@ -23,8 +24,13 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = Path(__file__).parent / "retryable_reads/unified" +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. 
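Many modules in this patch swap a single module-level _TEST_PATH for an _IS_SYNC-gated pair: the asynchronous twin of each test module lives one directory deeper than the synchronous one, while the JSON spec files stay in the synchronous test tree, so the path is resolved one level further up when _IS_SYNC is false. A minimal standalone sketch of the shared pattern (directory names follow the retryable-reads module above; pathlib-only here, whereas the patch itself mixes os.path.join and Path):

from pathlib import Path

_IS_SYNC = True  # the async variant of the module carries _IS_SYNC = False

# JSON spec files live only in the synchronous test tree, so the async
# variant (one directory deeper) must look one level further up.
if _IS_SYNC:
    TEST_PATH = Path(__file__).resolve().parent / "retryable_reads" / "unified"
else:
    TEST_PATH = Path(__file__).resolve().parent.parent / "retryable_reads" / "unified"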
diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 07bd1db0ba..598fc3fd76 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.utils import set_fail_point sys.path[0:0] = [""] @@ -30,12 +31,11 @@ unittest, ) from test.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, OvertCommandListener, - set_fail_point, ) from test.version import Version @@ -137,6 +137,7 @@ def setUp(self) -> None: self.deprecation_filter = DeprecationFilter() def tearDown(self) -> None: + super().tearDown() self.deprecation_filter.stop() @@ -194,6 +195,7 @@ def tearDown(self): SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) self.knobs.disable() + super().tearDown() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index da16166ec6..036c410e24 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -17,14 +17,20 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "unified") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_run_command.py b/test/test_run_command.py index 486a4c7e39..d2ef43b97e 100644 --- a/test/test_run_command.py +++ b/test/test_run_command.py @@ -1,15 +1,37 @@ +# Copyright 2024-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run Command unified tests.""" from __future__ import annotations import os import unittest +from pathlib import Path from test.unified_format import generate_test_classes -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "run_command") +_IS_SYNC = True + +# Location of JSON test specifications. 
+if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + os.path.join(TEST_PATH, "unified"), module=__name__, ) ) diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 6b808b159d..2167e561cf 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -15,15 +15,17 @@ """Run the sdam monitoring spec tests.""" from __future__ import annotations +import asyncio import json import os import sys import time +from pathlib import Path sys.path[0:0] = [""] from test import IntegrationTest, client_context, client_knobs, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, server_name_to_type, wait_until, @@ -39,8 +41,13 @@ from pymongo.synchronous.monitor import Monitor from pymongo.topology_description import TOPOLOGY_TYPE +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sdam_monitoring") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sdam_monitoring") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sdam_monitoring") def compare_server_descriptions(expected, actual): @@ -247,7 +254,7 @@ def _run(self): def create_tests(): - for dirpath, _, filenames in os.walk(_TEST_PATH): + for dirpath, _, filenames in os.walk(TEST_PATH): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json.load(scenario_stream, object_hook=object_hook) @@ -268,31 +275,33 @@ class TestSdamMonitoring(IntegrationTest): coll: Collection @classmethod - @client_context.require_failCommand_fail_point def setUpClass(cls): - super().setUp(cls) # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs( events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 ) cls.knobs.enable() cls.listener = ServerAndTopologyEventListener() - retry_writes = client_context.supports_transactions() - cls.test_client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=retry_writes - ) - cls.coll = cls.test_client[cls.client.db.name].test - cls.coll.insert_one({}) @classmethod def tearDownClass(cls): - cls.test_client.close() cls.knobs.disable() - super().tearDownClass() + @client_context.require_failCommand_fail_point def setUp(self): + super().setUp() + + retry_writes = client_context.supports_transactions() + self.test_client = self.rs_or_single_client( + event_listeners=[self.listener], retryWrites=retry_writes + ) + self.coll = self.test_client[self.client.db.name].test + self.coll.insert_one({}) self.listener.reset() + def tearDown(self): + super().tearDown() + def _test_app_error(self, fail_command_opts, expected_error): address = self.test_client.address @@ -334,7 +343,7 @@ def marked_unknown_and_rediscovered(): and len(self.listener.matching(discovered_node)) >= 1 ) - # Topology events are published asynchronously + # Topology events are not published synchronously wait_until(marked_unknown_and_rediscovered, "rediscover node") # Expect a single ServerDescriptionChangedEvent for the network error. 
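The sdam-monitoring changes above follow the fixture migration applied throughout this patch: clients built once per class in setUpClass move into setUp, so each test gets a fresh client whose cleanup runs even when a test fails, and per-test decorators such as require_failCommand_fail_point apply where they belong. A schematic before/after, with open()/close() standing in for client construction (illustrative only, not PyMongo API):

import unittest


class Before(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.resource = open(__file__)  # stand-in for an unmanaged client

    @classmethod
    def tearDownClass(cls):
        cls.resource.close()


class After(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.resource = open(__file__)  # stand-in for self.rs_or_single_client(...)
        self.addCleanup(self.resource.close)

    def test_resource_is_fresh(self):
        self.assertFalse(self.resource.closed)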
diff --git a/test/test_server.py b/test/test_server.py index 45d01c10de..ab5a40a79b 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -31,7 +31,7 @@ def test_repr(self): hello = Hello({"ok": 1}) sd = ServerDescription(("localhost", 27017), hello) server = Server(sd, pool=object(), monitor=object()) # type: ignore[arg-type] - self.assertTrue("Standalone" in str(server)) + self.assertIn("Standalone", str(server)) if __name__ == "__main__": diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 984b967f50..aec8e2e47a 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -17,6 +17,7 @@ import os import sys +from pathlib import Path from pymongo import MongoClient, ReadPreference from pymongo.errors import ServerSelectionTimeoutError @@ -30,24 +31,31 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( - EventListener, - FunctionCallRecorder, - OvertCommandListener, - wait_until, -) +from test.utils import wait_until from test.utils_selection_tests import ( create_selection_tests, - get_addresses, get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, make_server_description, ) +from test.utils_shared import ( + FunctionCallRecorder, + OvertCommandListener, +) + +_IS_SYNC = True # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.path.join("server_selection", "server_selection"), -) +if _IS_SYNC: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent, "server_selection", "server_selection" + ) +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "server_selection" + ) class SelectionStoreSelector: @@ -61,7 +69,7 @@ def __call__(self, selection): return selection -class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore pass @@ -79,13 +87,12 @@ def custom_selector(servers): client = self.rs_or_single_client( server_selector=custom_selector, event_listeners=[listener] ) - self.addCleanup(client.close) coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, "testdb") # Wait the node list to be fully populated. def all_hosts_started(): - return len(client.admin.command(HelloCompat.LEGACY_CMD)["hosts"]) == len( + return len((client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len( client._topology._description.readable_servers ) @@ -121,7 +128,6 @@ def test_selector_called(self): # Client setup. mongo_client = self.rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection - self.addCleanup(mongo_client.close) self.addCleanup(mongo_client.drop_database, "testdb") # Do N operations and test selector is called at least N times. 
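The server-selection tests above exercise the public server_selector client option: a callable that receives the list of eligible ServerDescription objects for an operation and returns the subset the driver may choose from. A hedged sketch of typical usage (the tag name and connection string are illustrative, not taken from the tests):

from pymongo import MongoClient


def prefer_ny_dc(server_descriptions):
    # Keep only members tagged {"dc": "ny"}; fall back to the full list so
    # selection still succeeds when no member carries the tag.
    preferred = [s for s in server_descriptions if s.tags.get("dc") == "ny"]
    return preferred or server_descriptions


client = MongoClient("mongodb://localhost:27017", server_selector=prefer_ny_dc)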
diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 05772fa385..4aad34050c 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -15,16 +15,18 @@ """Test the topology module's Server Selection Spec implementation.""" from __future__ import annotations +import asyncio import os import threading +from pathlib import Path from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.helpers import ConcurrentRunner +from test.utils_selection_tests import create_topology +from test.utils_shared import ( CMAPListener, OvertCommandListener, - get_pool, wait_until, ) -from test.utils_selection_tests import create_topology from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node @@ -32,10 +34,14 @@ from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference +_IS_SYNC = True # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) class TestAllScenarios(unittest.TestCase): @@ -92,7 +98,7 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() -class FinderThread(threading.Thread): +class FinderTask(ConcurrentRunner): def __init__(self, collection, iterations): super().__init__() self.daemon = True @@ -109,17 +115,17 @@ def run(self): class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test - N_THREADS = 10 - threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - for thread in threads: - self.assertTrue(thread.passed) + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) events = listener.started_events - self.assertEqual(len(events), n_finds * N_THREADS) + self.assertEqual(len(events), n_finds * N_TASKS) nodes = client.nodes self.assertEqual(len(nodes), 2) freqs = {address: 0.0 for address in nodes} diff --git a/test/test_server_selection_logging.py b/test/test_server_selection_logging.py index 2df749cb10..d53d8dc84f 100644 --- a/test/test_server_selection_logging.py +++ b/test/test_server_selection_logging.py @@ -17,19 +17,25 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. 
-_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection_logging") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") globals().update( generate_test_classes( - _TEST_PATH, + TEST_PATH, module=__name__, ) ) diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index a129af4585..2aef36a585 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -18,18 +18,24 @@ import json import os import sys +from pathlib import Path sys.path[0:0] = [""] -from test import unittest +from test import PyMongoTestCase, unittest from pymongo.read_preferences import MovingAverage +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt") -class TestAllScenarios(unittest.TestCase): +class TestAllScenarios(PyMongoTestCase): pass @@ -49,7 +55,7 @@ def run_scenario(self): def create_tests(): - for dirpath, _, filenames in os.walk(_TEST_PATH): + for dirpath, _, filenames in os.walk(TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: diff --git a/test/test_session.py b/test/test_session.py index 634efa11c0..a6266884aa 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from pymongo.synchronous.mongo_client import MongoClient @@ -27,22 +30,22 @@ from test import ( IntegrationTest, - PyMongoTestCase, SkipTest, UnitTest, client_context, unittest, ) -from test.utils import ( +from test.helpers import client_knobs +from test.utils_shared import ( EventListener, - ExceptionCatchingThread, + HeartbeatEventListener, OvertCommandListener, wait_until, ) from bson import DBRef from gridfs.synchronous.grid_file import GridFS, GridFSBucket -from pymongo import ASCENDING, MongoClient, monitoring +from pymongo import ASCENDING, MongoClient, _csot, monitoring from pymongo.common import _MAX_END_SESSIONS from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne @@ -184,7 +187,6 @@ def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @client_context.require_sync def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. 
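The hunk below rewrites this test's worker threads as ExceptionCatchingTask instances from test/helpers.py so the same body can run under both the threaded and asyncio test runners. The helper's assumed shape, judging only by how the test drives it (start/join plus an exc attribute), is roughly a thread that records its exception instead of losing it:

import threading


class ExceptionCatchingTask(threading.Thread):
    """Thread-like runner that stores any exception in self.exc."""

    def __init__(self, target, args=(), name=None):
        super().__init__(name=name)
        self.exc = None
        self._target_fn = target  # avoid clashing with Thread's own _target
        self._args = args

    def run(self):
        try:
            self._target_fn(*self._args)
        except BaseException as exc:
            self.exc = exc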
@@ -210,25 +212,26 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + def target(op, *args): + if iscoroutinefunction(op): + res = op(*args) + else: + res = op(*args) if isinstance(res, (Cursor, CommandCursor)): - list(res) # type: ignore[call-overload] + res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) + tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + t.join() + self.assertIsNone(t.exc) client.close() lsid_set.clear() for i in listener.started_events: @@ -538,9 +541,10 @@ def find(session=None): (bucket.download_to_stream_by_name, ["f", sio], {}), (find, [], {}), (bucket.rename, [1, "f2"], {}), + (bucket.rename_by_name, ["f2", "f3"], {}), # Delete both files so _test_ops can run these operations twice. (bucket.delete, [1], {}), - (bucket.delete, [2], {}), + (bucket.delete_by_name, ["f"], {}), ) def test_gridfsbucket_cursor(self): @@ -1119,10 +1123,10 @@ def setUp(self): if "$clusterTime" not in (client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) @@ -1201,6 +1205,38 @@ def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = self.single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. + heartbeat_listener.reset() + + def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + wait_until(next_heartbeat, "never found pair of heartbeat started + succeeded events") + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. 
+ cmd_listener.reset() + c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main() diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index c51b4642e7..3c80c70d38 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -17,14 +17,21 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") + # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index e01552bf7d..df802acb43 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -15,20 +15,23 @@ """Run the SRV support tests.""" from __future__ import annotations +import asyncio import sys -from time import sleep +import time +from test.utils_shared import FunctionCallRecorder from typing import Any sys.path[0:0] = [""] from test import PyMongoTestCase, client_knobs, unittest -from test.utils import FunctionCallRecorder, wait_until +from test.utils import wait_until import pymongo from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _have_dnspython -from pymongo.synchronous.mongo_client import MongoClient +from pymongo.synchronous.srv_resolver import _have_dnspython + +_IS_SYNC = True WAIT_TIME = 0.1 @@ -51,7 +54,9 @@ def __init__( def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -71,14 +76,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args): else: patch_func = mock_get_hosts_and_min_ttl - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -131,7 +136,10 @@ def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAI def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) return False wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) @@ -141,7 +149,7 @@ def predicate(): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) 
self.assertGreaterEqual( - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) @@ -168,6 +176,7 @@ def dns_resolver_response(): # Patch timeouts to ensure short test running times. with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING) + client._connect() self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( @@ -232,6 +241,7 @@ def final_callback(): ): # Client uses unpatched method to get initial nodelist client = self.simple_client(self.CONNECTION_STRING) + client._connect() # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) @@ -265,6 +275,7 @@ def nodelist_callback(): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -279,6 +290,7 @@ def nodelist_callback(): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -294,8 +306,9 @@ def nodelist_callback(): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): - sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) + time.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) final_topology = set(client.topology_description.server_descriptions()) self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology) self.assertEqual(len(final_topology), 2) @@ -303,8 +316,9 @@ def nodelist_callback(): def test_does_not_flipflop(self): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) + client._connect() old = set(client.topology_description.server_descriptions()) - sleep(4 * WAIT_TIME) + time.sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) self.assertSetEqual(old, new) @@ -322,6 +336,7 @@ def nodelist_callback(): client = self.simple_client( "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" ) + client._connect() with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -337,9 +352,9 @@ def resolver_response(): nodelist_callback=resolver_response, ): client = self.simple_client(self.CONNECTION_STRING) - self.assertRaises( - AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2 - ) + client._connect() + with self.assertRaises(AssertionError): + self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2) def test_import_dns_resolver(self): # Regression test for PYTHON-4407 diff --git a/test/test_ssl.py b/test/test_ssl.py index 04db9b61a4..9495a54364 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import socket import sys @@ -31,7 +32,7 @@ remove_all_users, unittest, ) -from 
test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, cat_files, @@ -42,7 +43,7 @@ from pymongo import MongoClient, ssl_support from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure from pymongo.hello import HelloCompat -from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context +from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context from pymongo.write_concern import WriteConcern _HAVE_PYOPENSSL = False @@ -65,7 +66,13 @@ if HAVE_SSL: import ssl -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +_IS_SYNC = True + +if _IS_SYNC: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates") +else: + CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates") + CLIENT_PEM = os.path.join(CERT_PATH, "client.pem") CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem") CA_PEM = os.path.join(CERT_PATH, "ca.pem") @@ -127,7 +134,7 @@ def test_config_ssl(self): @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") def test_use_pyopenssl_when_available(self): - self.assertTrue(_ssl.IS_PYOPENSSL) + self.assertTrue(HAVE_PYSSL) @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") def test_load_trusted_ca_certs(self): @@ -144,21 +151,18 @@ def assertClientWorks(self, client): ) coll.drop() coll.insert_one({"ssl": True}) - self.assertTrue(coll.find_one()["ssl"]) + self.assertTrue((coll.find_one())["ssl"]) coll.drop() - @classmethod @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() # MongoClient should connect to the primary by default. - cls.saved_port = MongoClient.PORT + self.saved_port = MongoClient.PORT MongoClient.PORT = client_context.port - @classmethod - def tearDownClass(cls): - MongoClient.PORT = cls.saved_port - super().tearDownClass() + def tearDown(self): + MongoClient.PORT = self.saved_port @client_context.require_tls def test_simple_ssl(self): @@ -173,7 +177,7 @@ def test_tlsCertificateKeyFilePassword(self): # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem - if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: + if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL: self.assertRaises( ConfigurationError, self.simple_client, @@ -305,13 +309,13 @@ def test_cert_ssl_validation_hostname_matching(self): # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem - ctx = get_ssl_context(None, None, None, None, True, True, False) + ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC) self.assertFalse(ctx.check_hostname) - ctx = get_ssl_context(None, None, None, None, True, False, False) + ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC) self.assertFalse(ctx.check_hostname) - ctx = get_ssl_context(None, None, None, None, False, True, False) + ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC) self.assertFalse(ctx.check_hostname) - ctx = get_ssl_context(None, None, None, None, False, False, False) + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) self.assertTrue(ctx.check_hostname) response = self.client.admin.command(HelloCompat.LEGACY_CMD) @@ -372,9 +376,11 @@ def test_cert_ssl_validation_hostname_matching(self): ) 
@client_context.require_tlsCertificateKeyFile + @client_context.require_sync + @client_context.require_no_api_version @ignore_deprecations def test_tlsCRLFile_support(self): - if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: + if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL: self.assertRaises( ConfigurationError, self.simple_client, @@ -465,7 +471,7 @@ def test_validation_with_system_ca_certs(self): ) def test_system_certs_config_error(self): - ctx = get_ssl_context(None, None, None, None, True, True, False) + ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC) if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr( ctx, "load_default_certs" ): @@ -496,11 +502,11 @@ def test_certifi_support(self): # Force the test on Windows, regardless of environment. ssl_support.HAVE_WINCERTSTORE = False try: - ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False) + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, CA_PEM) - ctx = get_ssl_context(None, None, None, None, False, False, False) + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where()) finally: @@ -517,11 +523,11 @@ def test_wincertstore(self): if not ssl_support.HAVE_WINCERTSTORE: raise SkipTest("Need wincertstore to test wincertstore.") - ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False) + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, CA_PEM) - ctx = get_ssl_context(None, None, None, None, False, False, False) + ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name) @@ -548,7 +554,6 @@ def test_mongodb_x509_auth(self): tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM, ) - self.addCleanup(noauth.close) with self.assertRaises(OperationFailure): noauth.pymongo_test.test.find_one() @@ -562,7 +567,6 @@ def test_mongodb_x509_auth(self): tlsCertificateKeyFile=CLIENT_PEM, event_listeners=[listener], ) - self.addCleanup(auth.close) # No error auth.pymongo_test.test.find_one() @@ -581,7 +585,6 @@ def test_mongodb_x509_auth(self): client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() @@ -589,7 +592,6 @@ def test_mongodb_x509_auth(self): client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() # Auth should fail if username and certificate do not match @@ -602,7 +604,6 @@ def test_mongodb_x509_auth(self): bad_client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) - self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() @@ -615,7 +616,6 @@ def test_mongodb_x509_auth(self): tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM, ) - self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() @@ 
-659,6 +659,14 @@ def remove(path): ) as client: self.assertTrue(client.admin.command("ping")) + @client_context.require_async + @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") + @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") + def test_pyopenssl_ignored_in_async(self): + client = MongoClient("mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true") + client.admin.command("ping") # command doesn't matter, just needs it to connect + client.close() + if __name__ == "__main__": unittest.main() diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index d782aa1dd7..acf7610c94 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( HeartbeatEventListener, ServerEventListener, wait_until, @@ -30,6 +30,8 @@ from pymongo import monitoring from pymongo.hello import HelloCompat +_IS_SYNC = True + class TestStreamingProtocol(IntegrationTest): @client_context.require_failCommand_appName @@ -41,7 +43,6 @@ def test_failCommand_streaming(self): heartbeatFrequencyMS=500, appName="failingHeartbeatTest", ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") address = client.address @@ -78,7 +79,7 @@ def marked_unknown(): def rediscovered(): return len(listener.matching(_discovered_node)) >= 1 - # Topology events are published asynchronously + # Topology events are not published synchronously wait_until(marked_unknown, "mark node unknown") wait_until(rediscovered, "rediscover node") @@ -108,7 +109,6 @@ def test_streaming_rtt(self): client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") address = client.address @@ -156,7 +156,6 @@ def test_monitor_waits_after_server_check_error(self): client = self.single_client( appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 ) - self.addCleanup(client.close) # Force a connection. client.admin.command("ping") duration = time.time() - start @@ -183,7 +182,6 @@ def test_heartbeat_awaited_flag(self): heartbeatFrequencyMS=500, appName="heartbeatEventAwaitedFlag", ) - self.addCleanup(client.close) # Force a connection. 
client.admin.command("ping") diff --git a/test/test_topology.py b/test/test_topology.py index 86aa87c2cc..22e94739ee 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -23,7 +23,8 @@ from test import client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, wait_until +from test.utils import MockPool +from test.utils_shared import wait_until from bson.objectid import ObjectId from pymongo import common diff --git a/test/test_transactions.py b/test/test_transactions.py index 949b88e60b..63ea5c74fe 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, wait_until, ) @@ -32,7 +32,7 @@ from bson import encode from bson.raw_bson import RawBSONDocument -from pymongo import WriteConcern +from pymongo import WriteConcern, _csot from pymongo.errors import ( CollectionInvalid, ConfigurationError, @@ -287,6 +287,14 @@ def gridfs_open_upload_stream(*args, **kwargs): "new-name", ), ), + ( + bucket.rename_by_name, + ( + "new-name", + "new-name2", + ), + ), + (bucket.delete_by_name, ("new-name2",)), ] with client.start_session() as s, s.start_transaction(): @@ -402,15 +410,10 @@ def setUp(self) -> None: for address in client_context.mongoses: self.mongos_clients.append(self.single_client("{}:{}".format(*address))) - def _set_fail_point(self, client, command_args): - cmd = {"configureFailPoint": "failCommand"} - cmd.update(command_args) - client.admin.command(cmd) - def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: - self._set_fail_point(client, command_args) + self.configure_fail_point(client, command_args) @client_context.require_transactions def test_callback_raises_custom_error(self): @@ -571,5 +574,29 @@ def callback(session): self.assertFalse(s.in_transaction) +class TestOptionsInsideTransactionProse(TransactionsBase): + @client_context.require_transactions + @client_context.require_no_standalone + def test_case_1(self): + # Write concern not inherited from collection object inside transaction + # Create a MongoClient running against a configured sharded/replica set/load balanced cluster. + client = client_context.client + coll = client[self.db.name].test + coll.delete_many({}) + # Start a new session on the client. + with client.start_session() as s: + # Start a transaction on the session. + s.start_transaction() + # Instantiate a collection object in the driver with a default write concern of { w: 0 }. + inner_coll = coll.with_options(write_concern=WriteConcern(w=0)) + # Insert the document { n: 1 } on the instantiated collection. + result = inner_coll.insert_one({"n": 1}, session=s) + # Commit the transaction. + s.commit_transaction() + # End the session. + # Ensure the document was inserted and no error was thrown from the transaction. 
+ assert result.inserted_id is not None + + if __name__ == "__main__": unittest.main() diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 81137bf658..641e05108a 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -17,12 +17,15 @@ import os import sys +from pathlib import Path sys.path[0:0] = [""] from test import client_context, unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + @client_context.require_no_mmap def setUpModule(): @@ -30,15 +33,21 @@ def setUpModule(): # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "unified") +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) # Location of JSON test specifications for transactions-convenient-api. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api", "unified" -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" + ) # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_unified_format.py b/test/test_unified_format.py index 1b3a134237..05f58d5d06 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -15,21 +15,28 @@ import os import sys +from pathlib import Path from typing import Any sys.path[0:0] = [""] -from test import unittest +from test import UnitTest, unittest from test.unified_format import MatchEvaluatorUtil, generate_test_classes from bson import ObjectId -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "unified-test-format") +_IS_SYNC = True + +# Location of JSON test specifications. 
+if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "valid-pass"), + os.path.join(TEST_PATH, "valid-pass"), module=__name__, class_name_prefix="UnifiedTestFormat", expected_failures=[ @@ -42,7 +49,7 @@ globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "valid-fail"), + os.path.join(TEST_PATH, "valid-fail"), module=__name__, class_name_prefix="UnifiedTestFormat", bypass_test_generation_errors=True, @@ -54,7 +61,7 @@ ) -class TestMatchEvaluatorUtil(unittest.TestCase): +class TestMatchEvaluatorUtil(UnitTest): def setUp(self): self.match_evaluator = MatchEvaluatorUtil(self) diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index f95717e95f..d4d17ac211 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -24,12 +24,13 @@ sys.path[0:0] = [""] from test import unittest +from unittest.mock import patch from bson.binary import JAVA_LEGACY from pymongo import ReadPreference from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.uri_parser import ( - parse_uri, +from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser_shared import ( parse_userinfo, split_hosts, split_options, diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 29cde7e078..aeb0be94b5 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -29,7 +29,7 @@ from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _have_snappy -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 7a25a507dc..19b125770f 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -13,28 +13,18 @@ # limitations under the License. from __future__ import annotations -import os import sys +from test import UnitTest sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test import unittest +from pymongo.mongo_client import MongoClient from pymongo.server_api import ServerApi, ServerApiVersion -from pymongo.synchronous.mongo_client import MongoClient -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "versioned-api") - -# Generate unified tests. 
-globals().update(generate_test_classes(TEST_PATH, module=__name__)) - - -class TestServerApi(IntegrationTest): - RUN_ON_LOAD_BALANCER = True - RUN_ON_SERVERLESS = True +class TestServerApi(UnitTest): def test_server_api_defaults(self): api = ServerApi(ServerApiVersion.V1) self.assertEqual(api.version, "1") @@ -74,35 +64,6 @@ def assertServerApiInAllCommands(self, events): for event in events: self.assertServerApi(event) - @client_context.require_version_min(4, 7) - def test_command_options(self): - listener = OvertCommandListener() - client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) - self.addCleanup(client.close) - coll = client.test.test - coll.insert_many([{} for _ in range(100)]) - self.addCleanup(coll.delete_many, {}) - list(coll.find(batch_size=25)) - client.admin.command("ping") - self.assertServerApiInAllCommands(listener.started_events) - - @client_context.require_version_min(4, 7) - @client_context.require_transactions - def test_command_options_txn(self): - listener = OvertCommandListener() - client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) - self.addCleanup(client.close) - coll = client.test.test - coll.insert_many([{} for _ in range(100)]) - self.addCleanup(coll.delete_many, {}) - - listener.reset() - with client.start_session() as s, s.start_transaction(): - coll.insert_many([{} for _ in range(100)], session=s) - list(coll.find(batch_size=25, session=s)) - client.test.command("find", "test", session=s) - self.assertServerApiInAllCommands(listener.started_events) - if __name__ == "__main__": unittest.main() diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py new file mode 100644 index 0000000000..0066ecd977 --- /dev/null +++ b/test/test_versioned_api_integration.py @@ -0,0 +1,82 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +from pathlib import Path +from test.unified_format import generate_test_classes + +sys.path[0:0] = [""] + +from test import IntegrationTest, client_context, unittest +from test.utils_shared import OvertCommandListener + +from pymongo.server_api import ServerApi + +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") + + +# Generate unified tests. 
+globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestServerApiIntegration(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + def assertServerApi(self, event): + self.assertIn("apiVersion", event.command) + self.assertEqual(event.command["apiVersion"], "1") + + def assertServerApiInAllCommands(self, events): + for event in events: + self.assertServerApi(event) + + @client_context.require_version_min(4, 7) + def test_command_options(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + coll = client.test.test + coll.insert_many([{} for _ in range(100)]) + self.addCleanup(coll.delete_many, {}) + coll.find(batch_size=25).to_list() + client.admin.command("ping") + self.assertServerApiInAllCommands(listener.started_events) + + @client_context.require_version_min(4, 7) + @client_context.require_transactions + def test_command_options_txn(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + coll = client.test.test + coll.insert_many([{} for _ in range(100)]) + self.addCleanup(coll.delete_many, {}) + + listener.reset() + with client.start_session() as s, s.start_transaction(): + coll.insert_many([{} for _ in range(100)], session=s) + coll.find(batch_size=25, session=s).to_list() + client.test.command("find", "test", session=s) + self.assertServerApiInAllCommands(listener.started_events) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/transactions-convenient-api/unified/commit-writeconcernerror.json b/test/transactions-convenient-api/unified/commit-writeconcernerror.json index a6f6e6bd7f..568f7ede42 100644 --- a/test/transactions-convenient-api/unified/commit-writeconcernerror.json +++ b/test/transactions-convenient-api/unified/commit-writeconcernerror.json @@ -56,7 +56,7 @@ ], "tests": [ { - "description": "commitTransaction is retried after WriteConcernFailed timeout error", + "description": "commitTransaction is retried after WriteConcernTimeout timeout error", "operations": [ { "name": "failPoint", @@ -74,7 +74,6 @@ ], "writeConcernError": { "code": 64, - "codeName": "WriteConcernFailed", "errmsg": "waiting for replication timed out", "errInfo": { "wtimeout": true @@ -236,7 +235,7 @@ ] }, { - "description": "commitTransaction is retried after WriteConcernFailed non-timeout error", + "description": "commitTransaction is retried after WriteConcernTimeout non-timeout error", "operations": [ { "name": "failPoint", @@ -254,7 +253,6 @@ ], "writeConcernError": { "code": 64, - "codeName": "WriteConcernFailed", "errmsg": "multiple errors reported" } } diff --git a/test/transactions/unified/error-labels.json b/test/transactions/unified/error-labels.json index be8df10ed3..74ed750b07 100644 --- a/test/transactions/unified/error-labels.json +++ b/test/transactions/unified/error-labels.json @@ -1176,7 +1176,7 @@ ] }, { - "description": "add UnknownTransactionCommitResult label to writeConcernError WriteConcernFailed", + "description": "add UnknownTransactionCommitResult label to writeConcernError WriteConcernTimeout", "operations": [ { "object": "testRunner", @@ -1338,7 +1338,7 @@ ] }, { - "description": "add UnknownTransactionCommitResult label to writeConcernError WriteConcernFailed with wtimeout", + "description": "add UnknownTransactionCommitResult label to writeConcernError WriteConcernTimeout with wtimeout", "operations": [ { "object": "testRunner", @@ -1356,7 
+1356,6 @@ ], "writeConcernError": { "code": 64, - "codeName": "WriteConcernFailed", "errmsg": "waiting for replication timed out", "errInfo": { "wtimeout": true diff --git a/test/unified-test-format/valid-pass/expectedError-isClientError.json b/test/unified-test-format/valid-pass/expectedError-isClientError.json new file mode 100644 index 0000000000..9c6beda588 --- /dev/null +++ b/test/unified-test-format/valid-pass/expectedError-isClientError.json @@ -0,0 +1,74 @@ +{ + "description": "expectedError-isClientError", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.0", + "topologies": [ + "single", + "replicaset" + ] + }, + { + "minServerVersion": "4.1.7", + "topologies": [ + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + } + ], + "tests": [ + { + "description": "isClientError considers network errors", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "ping" + ], + "closeConnection": true + } + } + } + }, + { + "name": "runCommand", + "object": "database0", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectError": { + "isClientError": true + } + } + ] + } + ] +} diff --git a/test/unified-test-format/valid-pass/operation-empty_array.json b/test/unified-test-format/valid-pass/operation-empty_array.json new file mode 100644 index 0000000000..93b25c983c --- /dev/null +++ b/test/unified-test-format/valid-pass/operation-empty_array.json @@ -0,0 +1,10 @@ +{ + "description": "operation-empty_array", + "schemaVersion": "1.0", + "tests": [ + { + "description": "Empty operations array", + "operations": [] + } + ] +} diff --git a/test/unified-test-format/valid-pass/operator-lte.json b/test/unified-test-format/valid-pass/operator-lte.json index 4a13b16d15..7a6a8057ad 100644 --- a/test/unified-test-format/valid-pass/operator-lte.json +++ b/test/unified-test-format/valid-pass/operator-lte.json @@ -42,7 +42,9 @@ "arguments": { "document": { "_id": 1, - "y": 1 + "x": 2, + "y": 3, + "z": 4 } } } @@ -58,10 +60,18 @@ "documents": [ { "_id": { - "$$lte": 1 + "$$lte": 2 + }, + "x": { + "$$lte": 2.1 }, "y": { - "$$lte": 2 + "$$lte": { + "$numberLong": "3" + } + }, + "z": { + "$$lte": 4 } } ] diff --git a/test/unified-test-format/valid-pass/operator-type-number_alias.json b/test/unified-test-format/valid-pass/operator-type-number_alias.json new file mode 100644 index 0000000000..e628d0d777 --- /dev/null +++ b/test/unified-test-format/valid-pass/operator-type-number_alias.json @@ -0,0 +1,174 @@ +{ + "description": "operator-type-number_alias", + "schemaVersion": "1.0", + "createEntities": [ + { + "client": { + "id": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "test", + "documents": [] + } + ], + "tests": [ + { + "description": "type number alias matches int32", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": { + "$numberInt": "2147483647" + } + } + } + }, + { 
+ "name": "find", + "object": "collection0", + "arguments": { + "filter": { + "_id": 1 + }, + "limit": 1 + }, + "expectResult": [ + { + "_id": 1, + "x": { + "$$type": "number" + } + } + ] + } + ] + }, + { + "description": "type number alias matches int64", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": { + "$numberLong": "9223372036854775807" + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + "_id": 1 + }, + "limit": 1 + }, + "expectResult": [ + { + "_id": 1, + "x": { + "$$type": "number" + } + } + ] + } + ] + }, + { + "description": "type number alias matches double", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": { + "$numberDouble": "2.71828" + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + "_id": 1 + }, + "limit": 1 + }, + "expectResult": [ + { + "_id": 1, + "x": { + "$$type": "number" + } + } + ] + } + ] + }, + { + "description": "type number alias matches decimal128", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": { + "$numberDecimal": "3.14159" + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + "_id": 1 + }, + "limit": 1 + }, + "expectResult": [ + { + "_id": 1, + "x": { + "$$type": "number" + } + } + ] + } + ] + } + ] +} diff --git a/test/unified_format.py b/test/unified_format.py index 372eb8abba..71d6cd50d4 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -48,10 +48,10 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( +from test.utils import get_pool +from test.utils_shared import ( camel_to_snake, camel_to_snake_args, - get_pool, parse_spec_options, prepare_spec_arguments, snake_to_camel, @@ -65,7 +65,7 @@ from bson import SON, json_util from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.objectid import ObjectId -from gridfs import GridFSBucket, GridOut +from gridfs import GridFSBucket, GridOut, NoFile from pymongo import ASCENDING, CursorType, MongoClient, _csot from pymongo.encryption_options import _HAVE_PYMONGOCRYPT from pymongo.errors import ( @@ -221,7 +221,6 @@ def __init__(self, test_class): self._listeners: Dict[str, EventListenerUtil] = {} self._session_lsids: Dict[str, Mapping[str, Any]] = {} self.test: UnifiedSpecTestMixinV1 = test_class - self._cluster_time: Mapping[str, Any] = {} def __contains__(self, item): return item in self._entities @@ -377,12 +376,14 @@ def drop(self: GridFSBucket, *args: Any, **kwargs: Any) -> None: opts["key_vault_client"], DEFAULT_CODEC_OPTIONS, opts.get("kms_tls_options", kms_tls_options), + opts.get("key_expiration_ms"), ) return elif entity_type == "thread": name = spec["id"] thread = SpecRunnerThread(name) thread.start() + self.test.addCleanup(thread.join, 5) self[name] = thread return @@ -418,13 +419,11 @@ def get_lsid_for_session(self, session_name): # session has been closed. 
         return self._session_lsids[session_name]
 
-    def advance_cluster_times(self) -> None:
+    def advance_cluster_times(self, cluster_time) -> None:
         """Manually synchronize entities when desired"""
-        if not self._cluster_time:
-            self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime")
         for entity in self._entities.values():
-            if isinstance(entity, ClientSession) and self._cluster_time:
-                entity.advance_cluster_time(self._cluster_time)
+            if isinstance(entity, ClientSession) and cluster_time:
+                entity.advance_cluster_time(cluster_time)
 
 
 class UnifiedSpecTestMixinV1(IntegrationTest):
@@ -437,7 +436,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
     a class attribute ``TEST_SPEC``.
     """
 
-    SCHEMA_VERSION = Version.from_string("1.21")
+    SCHEMA_VERSION = Version.from_string("1.22")
     RUN_ON_LOAD_BALANCER = True
     RUN_ON_SERVERLESS = True
     TEST_SPEC: Any
@@ -543,6 +542,14 @@ def maybe_skip_test(self, spec):
             self.skipTest("Implement PYTHON-1894")
         if "timeoutMS applied to entire download" in spec["description"]:
             self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
+        if (
+            "Error returned from connection pool clear with interruptInUseConnections=true is retryable"
+            in spec["description"]
+            and not _IS_SYNC
+        ):
+            self.skipTest("PYTHON-5170 tests are flakey")
+        if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
+            self.skipTest("PYTHON-5174 tests are flakey")
 
         class_name = self.__class__.__name__.lower()
         description = spec["description"].lower()
@@ -557,7 +564,11 @@ def maybe_skip_test(self, spec):
             self.skipTest("CSOT not implemented for watch()")
         if "cursors" in class_name:
             self.skipTest("CSOT not implemented for cursors")
-        if "tailable" in class_name:
+        if (
+            "tailable" in class_name
+            or "tailable" in description
+            and "non-tailable" not in description
+        ):
             self.skipTest("CSOT not implemented for tailable cursors")
         if "sessions" in class_name:
             self.skipTest("CSOT not implemented for sessions")
@@ -617,7 +628,7 @@ def process_error(self, exception, spec):
         # Connection errors are considered client errors.
         if isinstance(error, ConnectionFailure):
             self.assertNotIsInstance(error, NotPrimaryError)
-        elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
+        elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)):
             pass
         else:
             self.assertNotIsInstance(error, PyMongoError)
@@ -708,7 +719,7 @@ def _databaseOperation_runCommand(self, target, **kwargs):
         return target.command(**kwargs)
 
     def _databaseOperation_runCursorCommand(self, target, **kwargs):
-        return list(self._databaseOperation_createCommandCursor(target, **kwargs))
+        return (self._databaseOperation_createCommandCursor(target, **kwargs)).to_list()
 
     def _databaseOperation_createCommandCursor(self, target, **kwargs):
         self.__raise_if_unsupported("createCommandCursor", target, Database)
@@ -999,12 +1010,8 @@ def __set_fail_point(self, client, command_args):
         if not client_context.test_commands_enabled:
             self.skipTest("Test commands must be enabled")
 
-        cmd_on = SON([("configureFailPoint", "failCommand")])
-        cmd_on.update(command_args)
-        client.admin.command(cmd_on)
-        self.addCleanup(
-            client.admin.command, "configureFailPoint", cmd_on["configureFailPoint"], mode="off"
-        )
+        self.configure_fail_point(client, command_args)
+        self.addCleanup(self.configure_fail_point, client, command_args, off=True)
 
     def _testOperation_failPoint(self, spec):
         self.__set_fail_point(
@@ -1025,7 +1032,7 @@ def _testOperation_targetedFailPoint(self, spec):
 
     def _testOperation_createEntities(self, spec):
         self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
-        self.entity_map.advance_cluster_times()
+        self.entity_map.advance_cluster_times(self._cluster_time)
 
     def _testOperation_assertSessionTransactionState(self, spec):
         session = self.entity_map[spec["session"]]
@@ -1167,7 +1174,7 @@ def primary_changed() -> bool:
 
     def _testOperation_runOnThread(self, spec):
         """Run the 'runOnThread' operation."""
         thread = self.entity_map[spec["thread"]]
-        thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
+        thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
 
     def _testOperation_waitForThread(self, spec):
         """Run the 'waitForThread' operation."""
@@ -1374,7 +1381,6 @@ def run_scenario(self, spec, uri=None):
         # transaction (from a test failure) from blocking collection/database
         # operations during test set up and tear down.
         self.kill_all_sessions()
-        self.addCleanup(self.kill_all_sessions)
 
         if "csot" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flaky tests.
@@ -1382,7 +1388,11 @@ def run_scenario(self, spec, uri=None):
         for i in range(attempts):
             try:
                 return self._run_scenario(spec, uri)
-            except AssertionError:
+            except (AssertionError, OperationFailure) as exc:
+                if isinstance(exc, OperationFailure) and (
+                    _IS_SYNC or "failpoint" not in exc._message
+                ):
+                    raise
                 if i < attempts - 1:
                     print(
                         f"Retrying after attempt {i+1} of {self.id()} failed with:\n"
@@ -1415,11 +1425,12 @@ def _run_scenario(self, spec, uri=None):
         self._uri = uri
         self.entity_map = EntityMapUtil(self)
         self.entity_map.create_entities_from_spec(self.TEST_SPEC.get("createEntities", []), uri=uri)
+        self._cluster_time = None
         # process initialData
         if "initialData" in self.TEST_SPEC:
             self.insert_initial_data(self.TEST_SPEC["initialData"])
-            self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime")
-            self.entity_map.advance_cluster_times()
+            self._cluster_time = self.client._topology.max_cluster_time()
+            self.entity_map.advance_cluster_times(self._cluster_time)
 
         if "expectLogMessages" in spec:
             expect_log_messages = spec["expectLogMessages"]
diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py
index 0c685366f4..ea0f2f233e 100644
--- a/test/unified_format_shared.py
+++ b/test/unified_format_shared.py
@@ -35,7 +35,7 @@
     KMIP_CREDS,
     LOCAL_MASTER_KEY,
 )
-from test.utils import CMAPListener, camel_to_snake, parse_collection_options
+from test.utils_shared import CMAPListener, camel_to_snake, parse_collection_options
 from typing import Any, Union
 
 from bson import (
@@ -363,6 +363,7 @@ def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
     "decimal": (Decimal128,),
     "maxKey": (MaxKey,),
     "minKey": (MinKey,),
+    "number": (float, int, Int64, Decimal128),
 }
 
diff --git a/test/utils.py b/test/utils.py
index 69154bc63b..3027ed7517 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -12,418 +12,76 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
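+# NOTE: most helpers formerly defined in this module now live in test/utils_shared.py;
+# only the synchronous-specific test utilities remain below.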
-"""Utilities for testing pymongo""" +"""Utilities for testing pymongo that require synchronization.""" from __future__ import annotations import asyncio import contextlib -import copy -import functools -import os import random -import re -import shutil -import sys -import threading +import threading # Used in the synchronized version of this file import time -import unittest -import warnings from asyncio import iscoroutinefunction -from collections import abc, defaultdict -from functools import partial -from test import client_context, db_pwd, db_user -from test.asynchronous import async_client_context -from typing import Any, List -from bson import json_util -from bson.objectid import ObjectId from bson.son import SON -from pymongo import AsyncMongoClient, monitoring, operations, read_preferences -from pymongo.cursor_shared import CursorType -from pymongo.errors import ConfigurationError, OperationFailure +from pymongo import MongoClient +from pymongo.errors import ConfigurationError from pymongo.hello import HelloCompat -from pymongo.helpers_shared import _SENSITIVE_COMMANDS from pymongo.lock import _create_lock -from pymongo.monitoring import ( - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, -) from pymongo.operations import _Op -from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import any_server_selector, writable_server_selector -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration -from pymongo.uri_parser import parse_uri -from pymongo.write_concern import WriteConcern -IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) +_IS_SYNC = True -class BaseListener: - def __init__(self): - self.events = [] - - def reset(self): - self.events = [] - - def add_event(self, event): - self.events.append(event) - - def event_count(self, event_type): - return len(self.events_by_type(event_type)) - - def events_by_type(self, event_type): - """Return the matching events by event class. - - event_type can be a single class or a tuple of classes. 
- """ - return self.matching(lambda e: isinstance(e, event_type)) - - def matching(self, matcher): - """Return the matching events.""" - return [event for event in self.events[:] if matcher(event)] - - def wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") - - async def async_wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - await async_wait_until( - lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" - ) - - -class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): - def connection_created(self, event): - assert isinstance(event, ConnectionCreatedEvent) - self.add_event(event) - - def connection_ready(self, event): - assert isinstance(event, ConnectionReadyEvent) - self.add_event(event) - - def connection_closed(self, event): - assert isinstance(event, ConnectionClosedEvent) - self.add_event(event) - - def connection_check_out_started(self, event): - assert isinstance(event, ConnectionCheckOutStartedEvent) - self.add_event(event) - - def connection_check_out_failed(self, event): - assert isinstance(event, ConnectionCheckOutFailedEvent) - self.add_event(event) - - def connection_checked_out(self, event): - assert isinstance(event, ConnectionCheckedOutEvent) - self.add_event(event) - - def connection_checked_in(self, event): - assert isinstance(event, ConnectionCheckedInEvent) - self.add_event(event) - - def pool_created(self, event): - assert isinstance(event, PoolCreatedEvent) - self.add_event(event) - - def pool_ready(self, event): - assert isinstance(event, PoolReadyEvent) - self.add_event(event) - - def pool_cleared(self, event): - assert isinstance(event, PoolClearedEvent) - self.add_event(event) - - def pool_closed(self, event): - assert isinstance(event, PoolClosedEvent) - self.add_event(event) - - -class EventListener(BaseListener, monitoring.CommandListener): - def __init__(self): - super().__init__() - self.results = defaultdict(list) - - @property - def started_events(self) -> List[monitoring.CommandStartedEvent]: - return self.results["started"] - - @property - def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: - return self.results["succeeded"] - - @property - def failed_events(self) -> List[monitoring.CommandFailedEvent]: - return self.results["failed"] - - def started(self, event: monitoring.CommandStartedEvent) -> None: - self.started_events.append(event) - self.add_event(event) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - self.succeeded_events.append(event) - self.add_event(event) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - self.failed_events.append(event) - self.add_event(event) - - def started_command_names(self) -> List[str]: - """Return list of command names started.""" - return [event.command_name for event in self.started_events] - - def reset(self) -> None: - """Reset the state of this listener.""" - self.results.clear() - super().reset() - - -class TopologyEventListener(monitoring.TopologyListener): - def __init__(self): - self.results = defaultdict(list) - - def closed(self, event): - self.results["closed"].append(event) - - def description_changed(self, event): - self.results["description_changed"].append(event) - - def opened(self, event): - self.results["opened"].append(event) - - def reset(self): - """Reset the state of this listener.""" - self.results.clear() - - 
-class AllowListEventListener(EventListener): - def __init__(self, *commands): - self.commands = set(commands) - super().__init__() - - def started(self, event): - if event.command_name in self.commands: - super().started(event) - - def succeeded(self, event): - if event.command_name in self.commands: - super().succeeded(event) - - def failed(self, event): - if event.command_name in self.commands: - super().failed(event) - - -class OvertCommandListener(EventListener): - """A CommandListener that ignores sensitive commands.""" - - ignore_list_collections = False - - def started(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().started(event) - - def succeeded(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().succeeded(event) - - def failed(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().failed(event) - - -class _ServerEventListener: - """Listens to all events.""" - - def __init__(self): - self.results = [] - - def opened(self, event): - self.results.append(event) - - def description_changed(self, event): - self.results.append(event) - - def closed(self, event): - self.results.append(event) - - def matching(self, matcher): - """Return the matching events.""" - results = self.results[:] - return [event for event in results if matcher(event)] - - def reset(self): - self.results = [] - - -class ServerEventListener(_ServerEventListener, monitoring.ServerListener): - """Listens to Server events.""" - - -class ServerAndTopologyEventListener( # type: ignore[misc] - ServerEventListener, monitoring.TopologyListener -): - """Listens to Server and Topology events.""" - - -class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): - """Listens to only server heartbeat events.""" - - def started(self, event): - self.add_event(event) - - def succeeded(self, event): - self.add_event(event) - - def failed(self, event): - self.add_event(event) - - -class HeartbeatEventsListListener(HeartbeatEventListener): - """Listens to only server heartbeat events and publishes them to a provided list.""" - - def __init__(self, events): - super().__init__() - self.event_list = events - - def started(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatStartedEvent") - - def succeeded(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatSucceededEvent") - - def failed(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatFailedEvent") - - -class MockConnection: - def __init__(self): - self.cancel_context = _CancellationContext() - self.more_to_come = False - self.id = random.randint(0, 100) - - def close_conn(self, reason): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -class MockPool: - def __init__(self, address, options, handshake=True, client_id=None): - self.gen = _PoolGeneration() - self._lock = _create_lock() - self.opts = options - self.operation_count = 0 - self.conns = [] - - def stale_generation(self, gen, service_id): - return self.gen.stale(gen, service_id) - - def checkout(self, handler=None): - return MockConnection() - - def checkin(self, *args, **kwargs): - pass - - def _reset(self, service_id=None): - with self._lock: - self.gen.inc(service_id) - - def ready(self): - pass - - def reset(self, service_id=None, interrupt_connections=False): - self._reset() - - def reset_without_pause(self): - self._reset() - - def close(self): - self._reset() - - 
def update_is_writable(self, is_writable): - pass - - def remove_stale_sockets(self, *args, **kwargs): - pass - - -class ScenarioDict(dict): - """Dict that returns {} for any unknown key, recursively.""" - - def __init__(self, data): - def convert(v): - if isinstance(v, abc.Mapping): - return ScenarioDict(v) - if isinstance(v, (str, bytes)): - return v - if isinstance(v, abc.Sequence): - return [convert(item) for item in v] - return v - - dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) +def get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = client._get_topology() + server = topology._select_server(writable_server_selector, _Op.TEST) + return server.pool - def __getitem__(self, item): - try: - return dict.__getitem__(self, item) - except KeyError: - # Unlike a defaultdict, don't set the key, just return a dict. - return ScenarioDict({}) +def get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in (client._get_topology()).select_servers(any_server_selector, _Op.TEST) + ] -class CompareType: - """Class that compares equal to any object of the given type(s).""" - def __init__(self, types): - self.types = types +def wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. - def __eq__(self, other): - return isinstance(other, self.types) + E.g.: + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') -class FunctionCallRecorder: - """Utility class to wrap a callable and record its invocations.""" + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). - def __init__(self, function): - self._function = function - self._call_list = [] + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = predicate() + else: + retval = predicate() + if retval: + return retval - def __call__(self, *args, **kwargs): - self._call_list.append((args, kwargs)) - return self._function(*args, **kwargs) + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) - def reset(self): - """Wipes the call list.""" - self._call_list = [] + time.sleep(interval) - def call_list(self): - """Returns a copy of the call list.""" - return self._call_list[:] - @property - def call_count(self): - """Returns the number of times the function has been called.""" - return len(self._call_list) +def is_mongos(client): + res = client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" def ensure_all_connected(client: MongoClient) -> None: @@ -453,226 +111,17 @@ def discover(): return connected_host_list try: - wait_until(lambda: target_host_list == discover(), "connected to all hosts") - except AssertionError as exc: - raise AssertionError( - f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" - ) + def predicate(): + return target_host_list == discover() -async def async_ensure_all_connected(client: AsyncMongoClient) -> None: - """Ensure that the client's connection pool has socket connections to all - members of a replica set. Raises ConfigurationError when called with a - non-replica set client. - - Depending on the use-case, the caller may need to clear any event listeners - that are configured on the client. 
- """ - hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) - if "setName" not in hello: - raise ConfigurationError("cluster is not a replica set") - - target_host_list = set(hello["hosts"] + hello.get("passives", [])) - connected_host_list = {hello["me"]} - - # Run hello until we have connected to each host at least once. - async def discover(): - i = 0 - while i < 100 and connected_host_list != target_host_list: - hello: dict = await client.admin.command( - HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY - ) - connected_host_list.update([hello["me"]]) - i += 1 - return connected_host_list - - try: - - async def predicate(): - return target_host_list == await discover() - - await async_wait_until(predicate, "connected to all hosts") + wait_until(predicate, "connected to all hosts") except AssertionError as exc: raise AssertionError( f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" ) -def one(s): - """Get one element of a set""" - return next(iter(s)) - - -def oid_generated_on_process(oid): - """Makes a determination as to whether the given ObjectId was generated - by the current process, based on the 5-byte random number in the ObjectId. - """ - return ObjectId._random() == oid.binary[4:9] - - -def delay(sec): - return """function() { sleep(%f * 1000); return true; }""" % sec - - -def get_command_line(client): - command_line = client.admin.command("getCmdLineOpts") - assert command_line["ok"] == 1, "getCmdLineOpts() failed" - return command_line - - -def camel_to_snake(camel): - # Regex to convert CamelCase to snake_case. - snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() - - -def camel_to_upper_camel(camel): - return camel[0].upper() + camel[1:] - - -def camel_to_snake_args(arguments): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - arguments[c2s] = arguments.pop(arg_name) - return arguments - - -def snake_to_camel(snake): - # Regex to convert snake_case to lowerCamelCase. - return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) - - -def parse_collection_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - - if "writeConcern" in opts: - opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) - - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - - if "timeoutMS" in opts: - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 - return opts - - -def server_started_with_option(client, cmdline_opt, config_opt): - """Check if the server was started with a particular option. - - :Parameters: - - `cmdline_opt`: The command line option (i.e. --nojournal) - - `config_opt`: The config file option (i.e. nojournal) - """ - command_line = get_command_line(client) - if "parsed" in command_line: - parsed = command_line["parsed"] - if config_opt in parsed: - return parsed[config_opt] - argv = command_line["argv"] - return cmdline_opt in argv - - -def server_started_with_auth(client): - try: - command_line = get_command_line(client) - except OperationFailure as e: - assert e.details is not None - msg = e.details.get("errmsg", "") - if e.code == 13 or "unauthorized" in msg or "login" in msg: - # Unauthorized. 
- return True - raise - - # MongoDB >= 2.0 - if "parsed" in command_line: - parsed = command_line["parsed"] - # MongoDB >= 2.6 - if "security" in parsed: - security = parsed["security"] - # >= rc3 - if "authorization" in security: - return security["authorization"] == "enabled" - # < rc3 - return security.get("auth", False) or bool(security.get("keyFile")) - return parsed.get("auth", False) or bool(parsed.get("keyFile")) - # Legacy - argv = command_line["argv"] - return "--auth" in argv or "--keyFile" in argv - - -def joinall(threads): - """Join threads with a 5-minute timeout, assert joins succeeded""" - for t in threads: - t.join(300) - assert not t.is_alive(), "Thread %s hung" % t - - -def wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - time.sleep(interval) - - -async def async_wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - if iscoroutinefunction(predicate): - retval = await predicate() - else: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - await asyncio.sleep(interval) - - -def is_mongos(client): - res = client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - -async def async_is_mongos(client): - res = await client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - def assertRaisesExactly(cls, fn, *args, **kwargs): """ Unlike the standard assertRaises, this checks that a function raises a @@ -687,338 +136,75 @@ def assertRaisesExactly(cls, fn, *args, **kwargs): raise AssertionError("%s not raised" % cls) -async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): - """ - Unlike the standard assertRaises, this checks that a function raises a - specific class of exception, and not a subclass. E.g., check that - MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. 
- """ - try: - await fn(*args, **kwargs) - except Exception as e: - assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" - else: - raise AssertionError("%s not raised" % cls) - - -@contextlib.contextmanager -def _ignore_deprecations(): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - yield - - -def ignore_deprecations(wrapped=None): - """A context manager or a decorator.""" - if wrapped: - if iscoroutinefunction(wrapped): - - @functools.wraps(wrapped) - async def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return await wrapped(*args, **kwargs) - else: - - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return wrapped(*args, **kwargs) +def set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + client.admin.command(cmd) - return wrapper +def joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t else: - return _ignore_deprecations() - - -class DeprecationFilter: - def __init__(self, action="ignore"): - """Start filtering deprecations.""" - self.warn_context = warnings.catch_warnings() - self.warn_context.__enter__() - warnings.simplefilter(action, DeprecationWarning) - - def stop(self): - """Stop filtering deprecations.""" - self.warn_context.__exit__() # type: ignore - self.warn_context = None # type: ignore - - -def get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = client._get_topology() - server = topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -async def async_get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = await client._get_topology() - server = await topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -def get_pools(client): - """Get all pools.""" - return [ - server.pool - for server in client._get_topology().select_servers(any_server_selector, _Op.TEST) - ] - - -async def async_get_pools(client): - """Get all pools.""" - return [ - server.pool - async for server in await (await client._get_topology()).select_servers( - any_server_selector, _Op.TEST - ) - ] - - -# Constants for run_threads and lazy_client_trial. -NTRIALS = 5 -NTHREADS = 10 - - -def run_threads(collection, target): - """Run a target function in many threads. - - target is a function taking a Collection and an integer. - """ - threads = [] - for i in range(NTHREADS): - bound_target = partial(target, collection, i) - threads.append(threading.Thread(target=bound_target)) - - for t in threads: - t.start() - - for t in threads: - t.join(60) - assert not t.is_alive() - - -@contextlib.contextmanager -def frequent_thread_switches(): - """Make concurrency bugs more likely to manifest.""" - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - - try: - yield - finally: - sys.setswitchinterval(interval) - - -def lazy_client_trial(reset, target, test, get_client): - """Test concurrent operations on a lazily-connecting client. - - `reset` takes a collection and resets it for the next trial. - - `target` takes a lazily-connecting collection and an index from - 0 to NTHREADS, and performs some operation, e.g. an insert. - - `test` takes the lazily-connecting collection and asserts a - post-condition to prove `target` succeeded. 
- """ - collection = client_context.client.pymongo_test.test - - with frequent_thread_switches(): - for _i in range(NTRIALS): - reset(collection) - lazy_client = get_client() - lazy_collection = lazy_client.pymongo_test.test - run_threads(lazy_collection, target) - test(lazy_collection) - - -def gevent_monkey_patched(): - """Check if gevent's monkey patching is active.""" - try: - import socket - - import gevent.socket # type:ignore[import] - - return socket.socket is gevent.socket.socket - except ImportError: - return False - - -def eventlet_monkey_patched(): - """Check if eventlet's monkey patching is active.""" - import threading - - return threading.current_thread.__module__ == "eventlet.green.threading" - - -def is_greenthread_patched(): - return gevent_monkey_patched() or eventlet_monkey_patched() - - -class ExceptionCatchingThread(threading.Thread): - """A thread that stores any exception encountered from run().""" - - def __init__(self, *args, **kwargs): - self.exc = None - super().__init__(*args, **kwargs) + asyncio.wait([t.task for t in tasks if t is not None], timeout=300) - def run(self): - try: - super().run() - except BaseException as exc: - self.exc = exc - raise - - -def parse_read_preference(pref): - # Make first letter lowercase to match read_pref's modes. - mode_string = pref.get("mode", "primary") - mode_string = mode_string[:1].lower() + mode_string[1:] - mode = read_preferences.read_pref_mode_from_name(mode_string) - max_staleness = pref.get("maxStalenessSeconds", -1) - tag_sets = pref.get("tagSets") or pref.get("tag_sets") - return read_preferences.make_read_preference( - mode, tag_sets=tag_sets, max_staleness=max_staleness - ) - - -def server_name_to_type(name): - """Convert a ServerType name to the corresponding value. For SDAM tests.""" - # Special case, some tests in the spec include the PossiblePrimary - # type, but only single-threaded drivers need that type. We call - # possible primaries Unknown. 
- if name == "PossiblePrimary": - return SERVER_TYPE.Unknown - return getattr(SERVER_TYPE, name) - - -def cat_files(dest, *sources): - """Cat multiple files into dest.""" - with open(dest, "wb") as fdst: - for src in sources: - with open(src, "rb") as fsrc: - shutil.copyfileobj(fsrc, fdst) +class MockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + self.server_connection_id = random.randint(0, 100) -@contextlib.contextmanager -def assertion_context(msg): - """A context manager that adds info to an assertion failure.""" - try: - yield - except AssertionError as exc: - raise AssertionError(f"{msg}: {exc}") + def close_conn(self, reason): + pass + def __enter__(self): + return self -def parse_spec_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) + def __exit__(self, exc_type, exc_val, exc_tb): + pass - if "writeConcern" in opts: - w_opts = opts.pop("writeConcern") - if "journal" in w_opts: - w_opts["j"] = w_opts.pop("journal") - if "wtimeoutMS" in w_opts: - w_opts["wtimeout"] = w_opts.pop("wtimeoutMS") - opts["write_concern"] = WriteConcern(**dict(w_opts)) - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) +class MockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] - if "timeoutMS" in opts: - assert isinstance(opts["timeoutMS"], int) - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) - if "maxTimeMS" in opts: - opts["max_time_ms"] = opts.pop("maxTimeMS") + @contextlib.contextmanager + def checkout(self, handler=None): + yield MockConnection() - if "maxCommitTimeMS" in opts: - opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") + def checkin(self, *args, **kwargs): + pass - return dict(opts) + def _reset(self, service_id=None): + with self._lock: + self.gen.inc(service_id) + def ready(self): + pass -def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - # Named "key" instead not fieldName. - if arg_name == "fieldName": - arguments["key"] = arguments.pop(arg_name) - # Aggregate uses "batchSize", while find uses batch_size. - elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": - continue - elif arg_name == "timeoutMode": - raise unittest.SkipTest("PyMongo does not support timeoutMode") - # Requires boolean returnDocument. - elif arg_name == "returnDocument": - arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper()) - elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"): - # Parse each request into a bulk write model. 
- requests = [] - for request in arguments[c2s]: - if "name" in request: - # CRUD v2 format - bulk_model = camel_to_upper_camel(request["name"]) - bulk_class = getattr(operations, bulk_model) - bulk_arguments = camel_to_snake_args(request["arguments"]) - else: - # Unified test format - bulk_model, spec = next(iter(request.items())) - bulk_class = getattr(operations, camel_to_upper_camel(bulk_model)) - bulk_arguments = camel_to_snake_args(spec) - requests.append(bulk_class(**dict(bulk_arguments))) - arguments[c2s] = requests - elif arg_name == "session": - arguments["session"] = entity_map[arguments["session"]] - elif opname == "open_download_stream" and arg_name == "id": - arguments["file_id"] = arguments.pop(arg_name) - elif opname not in ("find", "find_one") and c2s == "max_time_ms": - # find is the only method that accepts snake_case max_time_ms. - # All other methods take kwargs which must use the server's - # camelCase maxTimeMS. See PYTHON-1855. - arguments["maxTimeMS"] = arguments.pop("max_time_ms") - elif opname == "with_transaction" and arg_name == "callback": - if "operations" in arguments[arg_name]: - # CRUD v2 format - callback_ops = arguments[arg_name]["operations"] - else: - # Unified test format - callback_ops = arguments[arg_name] - arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) - elif opname == "drop_collection" and arg_name == "collection": - arguments["name_or_collection"] = arguments.pop(arg_name) - elif opname == "create_collection": - if arg_name == "collection": - arguments["name"] = arguments.pop(arg_name) - arguments["check_exists"] = False - # Any other arguments to create_collection are passed through - # **kwargs. - elif opname == "create_index" and arg_name == "keys": - arguments["keys"] = list(arguments.pop(arg_name).items()) - elif opname == "drop_index" and arg_name == "name": - arguments["index_or_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "to": - arguments["new_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "dropTarget": - arguments["dropTarget"] = arguments.pop(arg_name) - elif arg_name == "cursorType": - cursor_type = arguments.pop(arg_name) - if cursor_type == "tailable": - arguments["cursor_type"] = CursorType.TAILABLE - elif cursor_type == "tailableAwait": - arguments["cursor_type"] = CursorType.TAILABLE - else: - raise AssertionError(f"Unsupported cursorType: {cursor_type}") - else: - arguments[c2s] = arguments.pop(arg_name) + def reset(self, service_id=None, interrupt_connections=False): + self._reset() + def reset_without_pause(self): + self._reset() -def set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - client.admin.command(cmd) + def close(self): + self._reset() + def update_is_writable(self, is_writable): + pass -async def async_set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - await client.admin.command(cmd) + def remove_stale_sockets(self, *args, **kwargs): + pass diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 2d21888e27..2772f06070 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -18,96 +18,29 @@ import datetime import os import sys +from test import PyMongoTestCase +from test.utils import MockPool sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, parse_read_preference 
+from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) +from test.utils_shared import parse_read_preference from bson import json_util -from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.common import HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.hello import Hello, HelloCompat from pymongo.operations import _Op -from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology - -def get_addresses(server_list): - seeds = [] - hosts = [] - for server in server_list: - seeds.append(clean_node(server["address"])) - hosts.append(server["address"]) - return seeds, hosts - - -def make_last_write_date(server): - epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - millis = server.get("lastWrite", {}).get("lastWriteDate") - if millis: - diff = ((millis % 1000) + 1000) % 1000 - seconds = (millis - diff) / 1000 - micros = diff * 1000 - return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) - else: - # "Unknown" server. - return epoch - - -def make_server_description(server, hosts): - """Make a ServerDescription from server info in a JSON test.""" - server_type = server["type"] - if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server["address"]), Hello({})) - - hello_response = {"ok": True, "hosts": hosts} - if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response["setName"] = "rs" - - if server_type == "RSPrimary": - hello_response[HelloCompat.LEGACY_CMD] = True - elif server_type == "RSSecondary": - hello_response["secondary"] = True - elif server_type == "Mongos": - hello_response["msg"] = "isdbgrid" - elif server_type == "RSGhost": - hello_response["isreplicaset"] = True - elif server_type == "RSArbiter": - hello_response["arbiterOnly"] = True - - hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - - for field in "maxWireVersion", "tags", "idleWritePeriodMillis": - if field in server: - hello_response[field] = server[field] - - hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) - - # Sets _last_update_time to now. - sd = ServerDescription( - clean_node(server["address"]), - Hello(hello_response), - round_trip_time=server["avg_rtt_ms"] / 1000.0, - ) - - if "lastUpdateTime" in server: - sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. - - return sd - - -def get_topology_type_name(scenario_def): - td = scenario_def["topology_description"] - name = td["type"] - if name == "Unknown": - # PyMongo never starts a topology in type Unknown. - return "Sharded" if len(td["servers"]) > 1 else "Single" - else: - return name +_IS_SYNC = True def get_topology_settings_dict(**kwargs): @@ -244,7 +177,7 @@ def run_scenario(self): def create_selection_tests(test_dir): - class TestAllScenarios(unittest.TestCase): + class TestAllScenarios(PyMongoTestCase): pass for dirpath, _, filenames in os.walk(test_dir): diff --git a/test/utils_selection_tests_shared.py b/test/utils_selection_tests_shared.py new file mode 100644 index 0000000000..dbaed1034f --- /dev/null +++ b/test/utils_selection_tests_shared.py @@ -0,0 +1,100 @@ +# Copyright 2015-present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys + +sys.path[0:0] = [""] + +from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription + + +def get_addresses(server_list): + seeds = [] + hosts = [] + for server in server_list: + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) + return seeds, hosts + + +def make_last_write_date(server): + epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) + millis = server.get("lastWrite", {}).get("lastWriteDate") + if millis: + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) / 1000 + micros = diff * 1000 + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) + else: + # "Unknown" server. + return epoch + + +def make_server_description(server, hosts): + """Make a ServerDescription from server info in a JSON test.""" + server_type = server["type"] + if server_type in ("Unknown", "PossiblePrimary"): + return ServerDescription(clean_node(server["address"]), Hello({})) + + hello_response = {"ok": True, "hosts": hosts} + if server_type not in ("Standalone", "Mongos", "RSGhost"): + hello_response["setName"] = "rs" + + if server_type == "RSPrimary": + hello_response[HelloCompat.LEGACY_CMD] = True + elif server_type == "RSSecondary": + hello_response["secondary"] = True + elif server_type == "Mongos": + hello_response["msg"] = "isdbgrid" + elif server_type == "RSGhost": + hello_response["isreplicaset"] = True + elif server_type == "RSArbiter": + hello_response["arbiterOnly"] = True + + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} + + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": + if field in server: + hello_response[field] = server[field] + + hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) + + # Sets _last_update_time to now. + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) + + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. + + return sd + + +def get_topology_type_name(scenario_def): + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": + # PyMongo never starts a topology in type Unknown. + return "Sharded" if len(td["servers"]) > 1 else "Single" + else: + return name diff --git a/test/utils_shared.py b/test/utils_shared.py new file mode 100644 index 0000000000..e0789b6632 --- /dev/null +++ b/test/utils_shared.py @@ -0,0 +1,709 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for testing pymongo""" +from __future__ import annotations + +import asyncio +import contextlib +import copy +import functools +import random +import re +import shutil +import sys +import threading +import unittest +import warnings +from asyncio import iscoroutinefunction +from collections import abc, defaultdict +from functools import partial +from test import client_context +from test.asynchronous.utils import async_wait_until +from test.utils import wait_until +from typing import List + +from bson.objectid import ObjectId +from pymongo import monitoring, operations, read_preferences +from pymongo.cursor_shared import CursorType +from pymongo.errors import OperationFailure +from pymongo.helpers_shared import _SENSITIVE_COMMANDS +from pymongo.lock import _async_create_lock, _create_lock +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_concern import ReadConcern +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration +from pymongo.write_concern import WriteConcern + +IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) + + +class BaseListener: + def __init__(self): + self.events = [] + + def reset(self): + self.events = [] + + def add_event(self, event): + self.events.append(event) + + def event_count(self, event_type): + return len(self.events_by_type(event_type)) + + def events_by_type(self, event_type): + """Return the matching events by event class. + + event_type can be a single class or a tuple of classes. 
+ """ + return self.matching(lambda e: isinstance(e, event_type)) + + def matching(self, matcher): + """Return the matching events.""" + return [event for event in self.events[:] if matcher(event)] + + def wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + + +class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): + def connection_created(self, event): + assert isinstance(event, ConnectionCreatedEvent) + self.add_event(event) + + def connection_ready(self, event): + assert isinstance(event, ConnectionReadyEvent) + self.add_event(event) + + def connection_closed(self, event): + assert isinstance(event, ConnectionClosedEvent) + self.add_event(event) + + def connection_check_out_started(self, event): + assert isinstance(event, ConnectionCheckOutStartedEvent) + self.add_event(event) + + def connection_check_out_failed(self, event): + assert isinstance(event, ConnectionCheckOutFailedEvent) + self.add_event(event) + + def connection_checked_out(self, event): + assert isinstance(event, ConnectionCheckedOutEvent) + self.add_event(event) + + def connection_checked_in(self, event): + assert isinstance(event, ConnectionCheckedInEvent) + self.add_event(event) + + def pool_created(self, event): + assert isinstance(event, PoolCreatedEvent) + self.add_event(event) + + def pool_ready(self, event): + assert isinstance(event, PoolReadyEvent) + self.add_event(event) + + def pool_cleared(self, event): + assert isinstance(event, PoolClearedEvent) + self.add_event(event) + + def pool_closed(self, event): + assert isinstance(event, PoolClosedEvent) + self.add_event(event) + + +class EventListener(BaseListener, monitoring.CommandListener): + def __init__(self): + super().__init__() + self.results = defaultdict(list) + + @property + def started_events(self) -> List[monitoring.CommandStartedEvent]: + return self.results["started"] + + @property + def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: + return self.results["succeeded"] + + @property + def failed_events(self) -> List[monitoring.CommandFailedEvent]: + return self.results["failed"] + + def started(self, event: monitoring.CommandStartedEvent) -> None: + self.started_events.append(event) + self.add_event(event) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + self.succeeded_events.append(event) + self.add_event(event) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + self.failed_events.append(event) + self.add_event(event) + + def started_command_names(self) -> List[str]: + """Return list of command names started.""" + return [event.command_name for event in self.started_events] + + def reset(self) -> None: + """Reset the state of this listener.""" + self.results.clear() + super().reset() + + +class TopologyEventListener(monitoring.TopologyListener): + def __init__(self): + self.results = defaultdict(list) + + def closed(self, event): + self.results["closed"].append(event) + + def description_changed(self, event): + self.results["description_changed"].append(event) + + def opened(self, event): + self.results["opened"].append(event) + + def reset(self): + """Reset the state of this listener.""" + self.results.clear() + + 
+class AllowListEventListener(EventListener): + def __init__(self, *commands): + self.commands = set(commands) + super().__init__() + + def started(self, event): + if event.command_name in self.commands: + super().started(event) + + def succeeded(self, event): + if event.command_name in self.commands: + super().succeeded(event) + + def failed(self, event): + if event.command_name in self.commands: + super().failed(event) + + +class OvertCommandListener(EventListener): + """A CommandListener that ignores sensitive commands.""" + + ignore_list_collections = False + + def started(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().started(event) + + def succeeded(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().succeeded(event) + + def failed(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().failed(event) + + +class _ServerEventListener: + """Listens to all events.""" + + def __init__(self): + self.results = [] + + def opened(self, event): + self.results.append(event) + + def description_changed(self, event): + self.results.append(event) + + def closed(self, event): + self.results.append(event) + + def matching(self, matcher): + """Return the matching events.""" + results = self.results[:] + return [event for event in results if matcher(event)] + + def reset(self): + self.results = [] + + +class ServerEventListener(_ServerEventListener, monitoring.ServerListener): + """Listens to Server events.""" + + +class ServerAndTopologyEventListener( # type: ignore[misc] + ServerEventListener, monitoring.TopologyListener +): + """Listens to Server and Topology events.""" + + +class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): + """Listens to only server heartbeat events.""" + + def started(self, event): + self.add_event(event) + + def succeeded(self, event): + self.add_event(event) + + def failed(self, event): + self.add_event(event) + + +class HeartbeatEventsListListener(HeartbeatEventListener): + """Listens to only server heartbeat events and publishes them to a provided list.""" + + def __init__(self, events): + super().__init__() + self.event_list = events + + def started(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatStartedEvent") + + def succeeded(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatSucceededEvent") + + def failed(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatFailedEvent") + + +class ScenarioDict(dict): + """Dict that returns {} for any unknown key, recursively.""" + + def __init__(self, data): + def convert(v): + if isinstance(v, abc.Mapping): + return ScenarioDict(v) + if isinstance(v, (str, bytes)): + return v + if isinstance(v, abc.Sequence): + return [convert(item) for item in v] + return v + + dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) + + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError: + # Unlike a defaultdict, don't set the key, just return a dict. 
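+            # Because that dict is itself a ScenarioDict, chained lookups
+            # like scenario["a"]["b"]["c"] return {} instead of raising.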
+            return ScenarioDict({})
+
+
+class CompareType:
+    """Class that compares equal to any object of the given type(s)."""
+
+    def __init__(self, types):
+        self.types = types
+
+    def __eq__(self, other):
+        return isinstance(other, self.types)
+
+
+class FunctionCallRecorder:
+    """Utility class to wrap a callable and record its invocations."""
+
+    def __init__(self, function):
+        self._function = function
+        self._call_list = []
+
+    def __call__(self, *args, **kwargs):
+        self._call_list.append((args, kwargs))
+        # Calling a coroutine function returns the (unawaited) coroutine, so
+        # sync and async callables can be invoked identically.
+        return self._function(*args, **kwargs)
+
+    def reset(self):
+        """Wipes the call list."""
+        self._call_list = []
+
+    def call_list(self):
+        """Returns a copy of the call list."""
+        return self._call_list[:]
+
+    @property
+    def call_count(self):
+        """Returns the number of times the function has been called."""
+        return len(self._call_list)
+
+
+def one(s):
+    """Get one element of a set."""
+    return next(iter(s))
+
+
+def oid_generated_on_process(oid):
+    """Determine whether the given ObjectId was generated by the current
+    process, based on the 5-byte random number in the ObjectId.
+    """
+    return ObjectId._random() == oid.binary[4:9]
+
+
+def delay(sec):
+    return """function() { sleep(%f * 1000); return true; }""" % sec
+
+
+def camel_to_snake(camel):
+    # Regex to convert CamelCase to snake_case.
+    snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower()
+
+
+def camel_to_upper_camel(camel):
+    return camel[0].upper() + camel[1:]
+
+
+def camel_to_snake_args(arguments):
+    for arg_name in list(arguments):
+        c2s = camel_to_snake(arg_name)
+        arguments[c2s] = arguments.pop(arg_name)
+    return arguments
+
+
+def snake_to_camel(snake):
+    # Regex to convert snake_case to lowerCamelCase.
+    return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake)
+
+
+def parse_collection_options(opts):
+    if "readPreference" in opts:
+        opts["read_preference"] = parse_read_preference(opts.pop("readPreference"))
+
+    if "writeConcern" in opts:
+        opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern")))
+
+    if "readConcern" in opts:
+        opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern")))
+
+    if "timeoutMS" in opts:
+        opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0
+    return opts
+
+
+@contextlib.contextmanager
+def _ignore_deprecations():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", DeprecationWarning)
+        yield
+
+
+def ignore_deprecations(wrapped=None):
+    """A context manager or a decorator."""
+    if wrapped:
+        if iscoroutinefunction(wrapped):
+
+            @functools.wraps(wrapped)
+            async def wrapper(*args, **kwargs):
+                with _ignore_deprecations():
+                    return await wrapped(*args, **kwargs)
+        else:
+
+            @functools.wraps(wrapped)
+            def wrapper(*args, **kwargs):
+                with _ignore_deprecations():
+                    return wrapped(*args, **kwargs)
+
+        return wrapper
+
+    else:
+        return _ignore_deprecations()
+
+
+class DeprecationFilter:
+    def __init__(self, action="ignore"):
+        """Start filtering deprecations."""
+        self.warn_context = warnings.catch_warnings()
+        self.warn_context.__enter__()
+        warnings.simplefilter(action, DeprecationWarning)
+
+    def stop(self):
+        """Stop filtering deprecations."""
+        self.warn_context.__exit__()  # type: ignore
+        self.warn_context = None  # type: ignore
+
+
+# Constants for run_threads and lazy_client_trial.
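+# lazy_client_trial runs NTRIALS trials, each running the target function
+# concurrently from NTHREADS threads.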
+NTRIALS = 5 +NTHREADS = 10 + + +def run_threads(collection, target): + """Run a target function in many threads. + + target is a function taking a Collection and an integer. + """ + threads = [] + for i in range(NTHREADS): + bound_target = partial(target, collection, i) + threads.append(threading.Thread(target=bound_target)) + + for t in threads: + t.start() + + for t in threads: + t.join(60) + assert not t.is_alive() + + +@contextlib.contextmanager +def frequent_thread_switches(): + """Make concurrency bugs more likely to manifest.""" + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + + try: + yield + finally: + sys.setswitchinterval(interval) + + +def lazy_client_trial(reset, target, test, get_client): + """Test concurrent operations on a lazily-connecting client. + + `reset` takes a collection and resets it for the next trial. + + `target` takes a lazily-connecting collection and an index from + 0 to NTHREADS, and performs some operation, e.g. an insert. + + `test` takes the lazily-connecting collection and asserts a + post-condition to prove `target` succeeded. + """ + collection = client_context.client.pymongo_test.test + + with frequent_thread_switches(): + for _i in range(NTRIALS): + reset(collection) + lazy_client = get_client() + lazy_collection = lazy_client.pymongo_test.test + run_threads(lazy_collection, target) + test(lazy_collection) + + +def gevent_monkey_patched(): + """Check if gevent's monkey patching is active.""" + try: + import socket + + import gevent.socket # type:ignore[import] + + return socket.socket is gevent.socket.socket + except ImportError: + return False + + +def eventlet_monkey_patched(): + """Check if eventlet's monkey patching is active.""" + import threading + + return threading.current_thread.__module__ == "eventlet.green.threading" + + +def is_greenthread_patched(): + return gevent_monkey_patched() or eventlet_monkey_patched() + + +def parse_read_preference(pref): + # Make first letter lowercase to match read_pref's modes. + mode_string = pref.get("mode", "primary") + mode_string = mode_string[:1].lower() + mode_string[1:] + mode = read_preferences.read_pref_mode_from_name(mode_string) + max_staleness = pref.get("maxStalenessSeconds", -1) + tag_sets = pref.get("tagSets") or pref.get("tag_sets") + return read_preferences.make_read_preference( + mode, tag_sets=tag_sets, max_staleness=max_staleness + ) + + +def server_name_to_type(name): + """Convert a ServerType name to the corresponding value. For SDAM tests.""" + # Special case, some tests in the spec include the PossiblePrimary + # type, but only single-threaded drivers need that type. We call + # possible primaries Unknown. 
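+    # All other names map directly onto SERVER_TYPE attributes, e.g.
+    # server_name_to_type("RSSecondary") == SERVER_TYPE.RSSecondary.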
+    if name == "PossiblePrimary":
+        return SERVER_TYPE.Unknown
+    return getattr(SERVER_TYPE, name)
+
+
+def cat_files(dest, *sources):
+    """Cat multiple files into dest."""
+    with open(dest, "wb") as fdst:
+        for src in sources:
+            with open(src, "rb") as fsrc:
+                shutil.copyfileobj(fsrc, fdst)
+
+
+@contextlib.contextmanager
+def assertion_context(msg):
+    """A context manager that adds info to an assertion failure."""
+    try:
+        yield
+    except AssertionError as exc:
+        raise AssertionError(f"{msg}: {exc}")
+
+
+def parse_spec_options(opts):
+    if "readPreference" in opts:
+        opts["read_preference"] = parse_read_preference(opts.pop("readPreference"))
+
+    if "writeConcern" in opts:
+        w_opts = opts.pop("writeConcern")
+        if "journal" in w_opts:
+            w_opts["j"] = w_opts.pop("journal")
+        if "wtimeoutMS" in w_opts:
+            w_opts["wtimeout"] = w_opts.pop("wtimeoutMS")
+        opts["write_concern"] = WriteConcern(**dict(w_opts))
+
+    if "readConcern" in opts:
+        opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern")))
+
+    if "timeoutMS" in opts:
+        assert isinstance(opts["timeoutMS"], int)
+        opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0
+
+    if "maxTimeMS" in opts:
+        opts["max_time_ms"] = opts.pop("maxTimeMS")
+
+    if "maxCommitTimeMS" in opts:
+        opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS")
+
+    return dict(opts)
+
+
+def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback):
+    for arg_name in list(arguments):
+        c2s = camel_to_snake(arg_name)
+        # The PyMongo argument is named "key", not "fieldName".
+        if arg_name == "fieldName":
+            arguments["key"] = arguments.pop(arg_name)
+        # Aggregate takes camelCase "batchSize" and "allowDiskUse" kwargs,
+        # while find uses snake_case batch_size.
+        elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate":
+            continue
+        elif arg_name == "bypassDocumentValidation" and (
+            opname == "aggregate" or "find_one_and" in opname
+        ):
+            continue
+        elif arg_name == "timeoutMode":
+            raise unittest.SkipTest("PyMongo does not support timeoutMode")
+        # Map the spec's string value to the boolean ReturnDocument constant.
+        elif arg_name == "returnDocument":
+            arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper())
+        elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"):
+            # Parse each request into a bulk write model.
+            requests = []
+            for request in arguments[c2s]:
+                if "name" in request:
+                    # CRUD v2 format
+                    bulk_model = camel_to_upper_camel(request["name"])
+                    bulk_class = getattr(operations, bulk_model)
+                    bulk_arguments = camel_to_snake_args(request["arguments"])
+                else:
+                    # Unified test format
+                    bulk_model, spec = next(iter(request.items()))
+                    bulk_class = getattr(operations, camel_to_upper_camel(bulk_model))
+                    bulk_arguments = camel_to_snake_args(spec)
+                requests.append(bulk_class(**dict(bulk_arguments)))
+            arguments[c2s] = requests
+        elif arg_name == "session":
+            arguments["session"] = entity_map[arguments["session"]]
+        elif opname == "open_download_stream" and arg_name == "id":
+            arguments["file_id"] = arguments.pop(arg_name)
+        elif opname not in ("find", "find_one") and c2s == "max_time_ms":
+            # find is the only method that accepts snake_case max_time_ms.
+            # All other methods take kwargs which must use the server's
+            # camelCase maxTimeMS. See PYTHON-1855.
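+            # (parse_spec_options converts the spec's maxTimeMS to
+            # max_time_ms, so convert it back for those methods here.)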
+            arguments["maxTimeMS"] = arguments.pop("max_time_ms")
+        elif opname == "with_transaction" and arg_name == "callback":
+            if "operations" in arguments[arg_name]:
+                # CRUD v2 format
+                callback_ops = arguments[arg_name]["operations"]
+            else:
+                # Unified test format
+                callback_ops = arguments[arg_name]
+            arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops))
+        elif opname == "drop_collection" and arg_name == "collection":
+            arguments["name_or_collection"] = arguments.pop(arg_name)
+        elif opname == "create_collection":
+            if arg_name == "collection":
+                arguments["name"] = arguments.pop(arg_name)
+            arguments["check_exists"] = False
+            # Any other arguments to create_collection are passed through
+            # **kwargs.
+        elif opname == "create_index" and arg_name == "keys":
+            arguments["keys"] = list(arguments.pop(arg_name).items())
+        elif opname == "drop_index" and arg_name == "name":
+            arguments["index_or_name"] = arguments.pop(arg_name)
+        elif opname == "rename" and arg_name == "to":
+            arguments["new_name"] = arguments.pop(arg_name)
+        elif opname == "rename" and arg_name == "dropTarget":
+            arguments["dropTarget"] = arguments.pop(arg_name)
+        elif arg_name == "cursorType":
+            cursor_type = arguments.pop(arg_name)
+            if cursor_type == "tailable":
+                arguments["cursor_type"] = CursorType.TAILABLE
+            elif cursor_type == "tailableAwait":
+                # A tailable-await cursor also sets the awaitData flag.
+                arguments["cursor_type"] = CursorType.TAILABLE_AWAIT
+            else:
+                raise AssertionError(f"Unsupported cursorType: {cursor_type}")
+        else:
+            arguments[c2s] = arguments.pop(arg_name)
+
+
+def create_async_event():
+    return asyncio.Event()
+
+
+def create_event():
+    return threading.Event()
+
+
+def async_create_barrier(n_tasks: int):
+    return asyncio.Barrier(n_tasks)
+
+
+def create_barrier(n_tasks: int, timeout: float | None = None):
+    return threading.Barrier(n_tasks, timeout=timeout)
+
+
+async def async_barrier_wait(barrier, timeout: float | None = None):
+    await asyncio.wait_for(barrier.wait(), timeout=timeout)
+
+
+def barrier_wait(barrier, timeout: float | None = None):
+    barrier.wait(timeout=timeout)
diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py
index 4508502cd0..580e7cc120 100644
--- a/test/utils_spec_runner.py
+++ b/test/utils_spec_runner.py
@@ -18,12 +18,13 @@
 import asyncio
 import functools
 import os
-import threading
+import time
 import unittest
 from asyncio import iscoroutinefunction
 from collections import abc
 from test import IntegrationTest, client_context, client_knobs
-from test.utils import (
+from test.helpers import ConcurrentRunner
+from test.utils_shared import (
     CMAPListener,
     CompareType,
     EventListener,
@@ -44,6 +45,7 @@
 from gridfs import GridFSBucket
 from gridfs.synchronous.grid_file import GridFSBucket
 from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
+from pymongo.lock import _cond_wait, _create_condition, _create_lock
 from pymongo.read_concern import ReadConcern
 from pymongo.read_preferences import ReadPreference
 from pymongo.results import BulkWriteResult, _WriteResult
@@ -55,15 +57,13 @@
 _IS_SYNC = True
 
 
-class SpecRunnerThread(threading.Thread):
+class SpecRunnerThread(ConcurrentRunner):
     def __init__(self, name):
-        super().__init__()
-        self.name = name
+        super().__init__(name=name)
         self.exc = None
         self.daemon = True
-        self.cond = threading.Condition()
+        self.cond = _create_condition(_create_lock())
         self.ops = []
-        self.stopped = False
 
     def schedule(self, work):
         self.ops.append(work)
@@ -79,7 +79,7 @@ def run(self):
         while not self.stopped or self.ops:
             if not self.ops:
                 with
self.cond: - self.cond.wait(10) + _cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) @@ -265,15 +265,10 @@ def setUp(self) -> None: def tearDown(self) -> None: self.knobs.disable() - def _set_fail_point(self, client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - client.admin.command(cmd) - def set_fail_point(self, command_args): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: - self._set_fail_point(client, command_args) + self.configure_fail_point(client, command_args) def targeted_fail_point(self, session, fail_point): """Run the targetedFailPoint test operation. @@ -282,7 +277,7 @@ def targeted_fail_point(self, session, fail_point): """ clients = {c.address: c for c in self.mongos_clients} client = clients[session._pinned_address] - self._set_fail_point(client, fail_point) + self.configure_fail_point(client, fail_point) self.addCleanup(self.set_fail_point, {"mode": "off"}) def assert_session_pinned(self, session): @@ -320,6 +315,10 @@ def assert_index_not_exists(self, database, collection, index): coll = self.client[database][collection] self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) + def wait(self, ms): + """Run the "wait" test operation.""" + time.sleep(ms / 1000.0) + def assertErrorLabelsContain(self, exc, expected_labels): labels = [l for l in expected_labels if exc.has_error_label(l)] self.assertEqual(labels, expected_labels) diff --git a/test/versioned-api/transaction-handling.json b/test/versioned-api/transaction-handling.json index c00c5240ae..32031296af 100644 --- a/test/versioned-api/transaction-handling.json +++ b/test/versioned-api/transaction-handling.json @@ -6,7 +6,7 @@ "minServerVersion": "4.9", "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ] } @@ -92,7 +92,7 @@ { "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ] } @@ -221,7 +221,7 @@ { "topologies": [ "replicaset", - "sharded-replicaset", + "sharded", "load-balanced" ] } diff --git a/tools/fail_if_no_c.py b/tools/fail_if_no_c.py index 6848e155aa..64280a81d2 100644 --- a/tools/fail_if_no_c.py +++ b/tools/fail_if_no_c.py @@ -18,34 +18,30 @@ """ from __future__ import annotations -import os -import subprocess +import logging import sys -from pathlib import Path + +LOGGER = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)-8s %(message)s") sys.path[0:0] = [""] import bson # noqa: E402 import pymongo # noqa: E402 -if not pymongo.has_c() or not bson.has_c(): - try: - from pymongo import _cmessage # type:ignore[attr-defined] # noqa: F401 - except Exception as e: - print(e) - try: - from bson import _cbson # type:ignore[attr-defined] # noqa: F401 - except Exception as e: - print(e) - sys.exit("could not load C extensions") - -if os.environ.get("ENSURE_UNIVERSAL2") == "1": - parent_dir = Path(pymongo.__path__[0]).parent - for pkg in ["pymongo", "bson", "grifs"]: - for so_file in Path(f"{parent_dir}/{pkg}").glob("*.so"): - print(f"Checking universal2 compatibility in {so_file}...") - output = subprocess.check_output(["file", so_file]) # noqa: S603, S607 - if "arm64" not in output.decode("utf-8"): - sys.exit("Universal wheel was not compiled with arm64 support") - if "x86_64" not in output.decode("utf-8"): - sys.exit("Universal wheel was not compiled with x86_64 support") + +def main() -> None: + if not pymongo.has_c() or not bson.has_c(): + try: + from pymongo import _cmessage # 
type:ignore[attr-defined] # noqa: F401 + except Exception as e: + LOGGER.exception(e) + try: + from bson import _cbson # type:ignore[attr-defined] # noqa: F401 + except Exception as e: + LOGGER.exception(e) + sys.exit("could not load C extensions") + + +if __name__ == "__main__": + main() diff --git a/tools/ocsptest.py b/tools/ocsptest.py index 521d048f79..8596db226d 100644 --- a/tools/ocsptest.py +++ b/tools/ocsptest.py @@ -35,6 +35,7 @@ def check_ocsp(host: str, port: int, capath: str) -> None: False, # allow_invalid_certificates False, # allow_invalid_hostnames False, + True, # is sync ) # disable_ocsp_endpoint_check # Ensure we're using pyOpenSSL. diff --git a/tools/synchro.py b/tools/synchro.py index dbcbbd1351..bfe8f71125 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -47,6 +47,7 @@ "async_receive_message": "receive_message", "async_receive_data": "receive_data", "async_sendall": "sendall", + "async_socket_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous", "AsyncBulkTestBase": "BulkTestBase", @@ -119,6 +120,20 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncNetworkingInterface": "NetworkingInterface", + "_configured_protocol_interface": "_configured_socket_interface", + "_async_configured_socket": "_configured_socket", + "SpecRunnerTask": "SpecRunnerThread", + "AsyncMockConnection": "MockConnection", + "AsyncMockPool": "MockPool", + "StopAsyncIteration": "StopIteration", + "create_async_event": "create_event", + "async_create_barrier": "create_barrier", + "async_barrier_wait": "barrier_wait", + "async_joinall": "joinall", + "_async_create_connection": "_create_connection", + "pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts", + "dns.asyncresolver.resolve": "dns.resolver.resolve", } docstring_replacements: dict[tuple[str, str], str] = { @@ -165,7 +180,12 @@ def async_only_test(f: str) -> bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py"] + return f in [ + "test_locks.py", + "test_concurrency.py", + "test_async_cancellation.py", + "test_async_loop_safety.py", + ] test_files = [ @@ -198,21 +218,60 @@ def async_only_test(f: str) -> bool: "test_comment.py", "test_common.py", "test_connection_logging.py", + "test_connection_monitoring.py", "test_connections_survive_primary_stepdown_spec.py", "test_create_entities.py", "test_crud_unified.py", + "test_csot.py", "test_cursor.py", + "test_custom_types.py", "test_database.py", + "test_data_lake.py", + "test_discovery_and_monitoring.py", + "test_dns.py", "test_encryption.py", + "test_examples.py", "test_grid_file.py", + "test_gridfs.py", + "test_gridfs_bucket.py", + "test_gridfs_spec.py", + "test_heartbeat_monitoring.py", + "test_index_management.py", + "test_json_util_integration.py", + "test_load_balancer.py", "test_logger.py", + "test_max_staleness.py", + "test_monitor.py", "test_monitoring.py", + "test_mongos_load_balancing.py", + "test_on_demand_csfle.py", + "test_pooling.py", "test_raw_bson.py", + "test_read_concern.py", + "test_read_preferences.py", + "test_read_write_concern_spec.py", "test_retryable_reads.py", + "test_retryable_reads_unified.py", "test_retryable_writes.py", + "test_retryable_writes_unified.py", + "test_run_command.py", + "test_sdam_monitoring_spec.py", + "test_server_selection.py", + "test_server_selection_in_window.py", + 
"test_server_selection_logging.py", + "test_server_selection_rtt.py", "test_session.py", + "test_sessions_unified.py", + "test_srv_polling.py", + "test_ssl.py", + "test_streaming_protocol.py", "test_transactions.py", + "test_transactions_unified.py", + "test_unified_format.py", + "test_versioned_api_integration.py", "unified_format.py", + "utils_selection_tests.py", + "utils.py", ] @@ -229,7 +288,8 @@ def process_files( if file in docstring_translate_files: lines = translate_docstrings(lines) if file in sync_test_files: - translate_imports(lines) + lines = translate_imports(lines) + lines = process_ignores(lines) f.seek(0) f.writelines(lines) f.truncate() @@ -331,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]: return [line for line in lines if line != "DOCSTRING_REMOVED"] +def process_ignores(lines: list[str]) -> list[str]: + for i in range(len(lines)): + for k, v in replacements.items(): + if "unasync: off" in lines[i] and v in lines[i]: + lines[i] = lines[i].replace(v, k) + return lines + + def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None: unasync_files( files, diff --git a/tools/synchro.sh b/tools/synchro.sh index 51c51a9548..28b9c6d6c4 100755 --- a/tools/synchro.sh +++ b/tools/synchro.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# Keep the synchronous folders in sync with there async counterparts. set -eu python ./tools/synchro.py "$@" diff --git a/uv.lock b/uv.lock index e7f09f66fc..87d94ae76e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -997,7 +998,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5 [[package]] name = "pymongo" -version = "4.11.0.dev0" +version = "4.12.0" source = { editable = "." 
} dependencies = [ { name = "dnspython" }, @@ -1063,6 +1064,9 @@ mockupdb = [ perf = [ { name = "simplejson" }, ] +pip = [ + { name = "pip" }, +] pymongocrypt-source = [ { name = "pymongocrypt" }, ] @@ -1083,7 +1087,7 @@ requires-dist = [ { name = "pykerberos", marker = "os_name != 'nt' and extra == 'gssapi'" }, { name = "pymongo-auth-aws", marker = "extra == 'aws'", specifier = ">=1.1.0,<2.0.0" }, { name = "pymongo-auth-aws", marker = "extra == 'encryption'", specifier = ">=1.1.0,<2.0.0" }, - { name = "pymongocrypt", marker = "extra == 'encryption'", specifier = ">=1.12.0,<2.0.0" }, + { name = "pymongocrypt", marker = "extra == 'encryption'", specifier = ">=1.13.0,<2.0.0" }, { name = "pyopenssl", marker = "extra == 'ocsp'", specifier = ">=17.2.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.2" }, { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.24.0" }, @@ -1098,6 +1102,7 @@ requires-dist = [ { name = "winkerberos", marker = "os_name == 'nt' and extra == 'gssapi'", specifier = ">=0.5.0" }, { name = "zstandard", marker = "extra == 'zstd'" }, ] +provides-extras = ["aws", "docs", "encryption", "gssapi", "ocsp", "snappy", "test", "zstd"] [package.metadata.requires-dev] coverage = [ @@ -1109,6 +1114,7 @@ eventlet = [{ name = "eventlet" }] gevent = [{ name = "gevent" }] mockupdb = [{ name = "mockupdb", git = "https://github.com/mongodb-labs/mongo-mockup-db?rev=master" }] perf = [{ name = "simplejson" }] +pip = [{ name = "pip" }] pymongocrypt-source = [{ name = "pymongocrypt", git = "https://github.com/mongodb/libmongocrypt?subdirectory=bindings%2Fpython&rev=master" }] typing = [ { name = "mypy", specifier = "==1.14.1" }, @@ -1132,8 +1138,8 @@ wheels = [ [[package]] name = "pymongocrypt" -version = "1.13.0.dev0" -source = { git = "https://github.com/mongodb/libmongocrypt?subdirectory=bindings%2Fpython&rev=master#90476d5db7737bab2ce1c198df5671a12dbaae1a" } +version = "1.14.0.dev0" +source = { git = "https://github.com/mongodb/libmongocrypt?subdirectory=bindings%2Fpython&rev=master#af621673c46d3d8fd2a2fe9d5540e24a79d9357a" } dependencies = [ { name = "cffi" }, { name = "cryptography" },