Skip to content

Rename result type in some docs and tests #1462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.usage import UsageLimits


class NeverResultType(TypedDict):
class NeverOutputType(TypedDict):
"""
Never ever coerce data to this type.
"""
Expand All @@ -429,7 +429,7 @@ class NeverResultType(TypedDict):
agent = Agent(
'anthropic:claude-3-5-sonnet-latest',
retries=3,
output_type=NeverResultType,
output_type=NeverOutputType,
system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.',
)

Expand Down
24 changes: 12 additions & 12 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,14 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any:
@pytest.mark.parametrize(
'union_code',
[
pytest.param('ResultType = Union[Foo, Bar]'),
pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
pytest.param('OutputType = Union[Foo, Bar]'),
pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
pytest.param(
'ResultType: TypeAlias = Foo | Bar',
'OutputType: TypeAlias = Foo | Bar',
marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
),
pytest.param(
'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
'type OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
),
],
)
Expand All @@ -470,7 +470,7 @@ class Bar(BaseModel):
mod = create_module(module_code)

m = TestModel()
agent = Agent(m, output_type=mod.ResultType)
agent = Agent(m, output_type=mod.OutputType)
got_tool_call_name = 'unset'

@agent.output_validator
Expand Down Expand Up @@ -983,7 +983,7 @@ class TestMultipleToolCalls:

pytestmark = pytest.mark.usefixtures('set_event_loop')

class ResultType(BaseModel):
class OutputType(BaseModel):
"""Result type used by all tests."""

value: str
Expand All @@ -1002,7 +1002,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
]
)

agent = Agent(FunctionModel(return_model), output_type=self.ResultType, end_strategy='early')
agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int: # pragma: no cover
Expand Down Expand Up @@ -1058,7 +1058,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
]
)

agent = Agent(FunctionModel(return_model), output_type=self.ResultType, end_strategy='early')
agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early')
result = agent.run_sync('test multiple final results')

# Verify the result came from the first final tool
Expand Down Expand Up @@ -1098,7 +1098,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
]
)

agent = Agent(FunctionModel(return_model), output_type=self.ResultType, end_strategy='exhaustive')
agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='exhaustive')

@agent.tool_plain
def regular_tool(x: int) -> int:
Expand Down Expand Up @@ -1186,7 +1186,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
]
)

agent = Agent(FunctionModel(return_model), output_type=self.ResultType, end_strategy='early')
agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int: # pragma: no cover
Expand Down Expand Up @@ -1259,7 +1259,7 @@ def another_tool(y: int) -> int: # pragma: no cover
def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self):
"""Test that 'early' strategy does not apply to tool calls without final tool."""
tool_called = []
agent = Agent(TestModel(), output_type=self.ResultType, end_strategy='early')
agent = Agent(TestModel(), output_type=self.OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int:
Expand All @@ -1285,7 +1285,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
]
)

agent = Agent(FunctionModel(return_model), output_type=self.ResultType, end_strategy='early')
agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early')
result = agent.run_sync('test multiple final results')

# Verify the result came from the second final tool
Expand Down
34 changes: 17 additions & 17 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ async def ret_a(x: str) -> str: # pragma: no cover
)


class ResultType(BaseModel):
class OutputType(BaseModel):
"""Result type used by all tests."""

value: str
Expand All @@ -407,7 +407,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
yield {2: DeltaToolCall('regular_tool', '{"x": 1}')}
yield {3: DeltaToolCall('another_tool', '{"y": 2}')}

agent = Agent(FunctionModel(stream_function=sf), output_type=ResultType, end_strategy='early')
agent = Agent(FunctionModel(stream_function=sf), output_type=OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int: # pragma: no cover
Expand Down Expand Up @@ -476,7 +476,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
yield {1: DeltaToolCall('final_result', '{"value": "first"}')}
yield {2: DeltaToolCall('final_result', '{"value": "second"}')}

agent = Agent(FunctionModel(stream_function=sf), output_type=ResultType, end_strategy='early')
agent = Agent(FunctionModel(stream_function=sf), output_type=OutputType, end_strategy='early')

async with agent.run_stream('test multiple final results') as result:
response = await result.get_output()
Expand Down Expand Up @@ -529,7 +529,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
yield {4: DeltaToolCall('final_result', '{"value": "second"}')}
yield {5: DeltaToolCall('unknown_tool', '{"value": "???"}')}

agent = Agent(FunctionModel(stream_function=sf), output_type=ResultType, end_strategy='exhaustive')
agent = Agent(FunctionModel(stream_function=sf), output_type=OutputType, end_strategy='exhaustive')

@agent.tool_plain
def regular_tool(x: int) -> int:
Expand Down Expand Up @@ -606,7 +606,7 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
yield {3: DeltaToolCall('another_tool', '{"y": 2}')}
yield {4: DeltaToolCall('unknown_tool', '{"value": "???"}')}

agent = Agent(FunctionModel(stream_function=sf), output_type=ResultType, end_strategy='early')
agent = Agent(FunctionModel(stream_function=sf), output_type=OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int: # pragma: no cover
Expand Down Expand Up @@ -715,7 +715,7 @@ def another_tool(y: int) -> int: # pragma: no cover
async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool():
"""Test that 'early' strategy does not apply to tool calls without final tool."""
tool_called: list[str] = []
agent = Agent(TestModel(), output_type=ResultType, end_strategy='early')
agent = Agent(TestModel(), output_type=OutputType, end_strategy='early')

@agent.tool_plain
def regular_tool(x: int) -> int:
Expand Down Expand Up @@ -777,17 +777,17 @@ async def test_custom_output_type_default_str() -> None:
response = await result.get_output()
assert response == snapshot('success (no tool calls)')

async with agent.run_stream('test', output_type=ResultType) as result:
async with agent.run_stream('test', output_type=OutputType) as result:
response = await result.get_output()
assert response == snapshot(ResultType(value='a'))
assert response == snapshot(OutputType(value='a'))


async def test_custom_output_type_default_structured() -> None:
agent = Agent('test', output_type=ResultType)
agent = Agent('test', output_type=OutputType)

async with agent.run_stream('test') as result:
response = await result.get_output()
assert response == snapshot(ResultType(value='a'))
assert response == snapshot(OutputType(value='a'))

async with agent.run_stream('test', output_type=str) as result:
response = await result.get_output()
Expand Down Expand Up @@ -880,21 +880,21 @@ def output_validator_simple(data: str) -> str:


async def test_stream_iter_structured_validator() -> None:
class NotResultType(BaseModel):
class NotOutputType(BaseModel):
not_value: str

agent = Agent[None, Union[ResultType, NotResultType]]('test', output_type=Union[ResultType, NotResultType]) # pyright: ignore[reportArgumentType]
agent = Agent[None, Union[OutputType, NotOutputType]]('test', output_type=Union[OutputType, NotOutputType]) # pyright: ignore[reportArgumentType]

@agent.output_validator
def output_validator(data: ResultType | NotResultType) -> ResultType | NotResultType:
assert isinstance(data, ResultType)
return ResultType(value=data.value + ' (validated)')
def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutputType:
assert isinstance(data, OutputType)
return OutputType(value=data.value + ' (validated)')

outputs: list[ResultType] = []
outputs: list[OutputType] = []
async with agent.iter('test') as run:
async for node in run:
if agent.is_model_request_node(node):
async with node.stream(run.ctx) as stream:
async for output in stream.stream_output(debounce_by=None):
outputs.append(output)
assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')]
assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')]