Skip to content

Commit 3dc86d3

Browse files
mkheckmarkpollack
authored andcommitted
Added Entra ID identity management for Azure OpenAI, clean autoconfiguration, and updated docs to reflect changes.
Signed-off-by: Mark Heckler <[email protected]>
1 parent 228ef10 commit 3dc86d3

File tree

6 files changed

+330
-42
lines changed

6 files changed

+330
-42
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@
103103
<artifactId>mockito-core</artifactId>
104104
<scope>test</scope>
105105
</dependency>
106-
</dependencies>
106+
<dependency>
107+
<groupId>com.azure</groupId>
108+
<artifactId>azure-identity</artifactId>
109+
<version>${azure-identity.version}</version>
110+
<scope>compile</scope>
111+
</dependency>
112+
</dependencies>
107113

108114
</project>

auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@
2323
import com.azure.ai.openai.OpenAIClientBuilder;
2424
import com.azure.core.credential.AzureKeyCredential;
2525
import com.azure.core.credential.KeyCredential;
26-
import com.azure.core.credential.TokenCredential;
2726
import com.azure.core.util.ClientOptions;
2827
import com.azure.core.util.Header;
28+
import com.azure.identity.DefaultAzureCredentialBuilder;
2929

3030
import org.springframework.beans.factory.ObjectProvider;
31-
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
3231
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
3332
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
3433
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -55,48 +54,39 @@ public class AzureOpenAiClientBuilderConfiguration {
5554
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties,
5655
ObjectProvider<AzureOpenAIClientBuilderCustomizer> customizers) {
5756

58-
if (StringUtils.hasText(connectionProperties.getApiKey())) {
59-
60-
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
61-
62-
Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
63-
List<Header> headers = customHeaders.entrySet()
64-
.stream()
65-
.map(entry -> new Header(entry.getKey(), entry.getValue()))
66-
.collect(Collectors.toList());
67-
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
68-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
69-
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
70-
.clientOptions(clientOptions);
71-
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
72-
return clientBuilder;
73-
}
57+
final OpenAIClientBuilder clientBuilder;
7458

7559
// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
7660
// used as OpenAI model name.
7761
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
78-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
62+
clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
7963
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
8064
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
8165
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
8266
return clientBuilder;
8367
}
8468

85-
throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty");
86-
}
69+
Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
70+
List<Header> headers = customHeaders.entrySet()
71+
.stream()
72+
.map(entry -> new Header(entry.getKey(), entry.getValue()))
73+
.collect(Collectors.toList());
74+
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
8775

88-
@Bean
89-
@ConditionalOnMissingBean
90-
@ConditionalOnBean(TokenCredential.class)
91-
public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties,
92-
TokenCredential tokenCredential, ObjectProvider<AzureOpenAIClientBuilderCustomizer> customizers) {
93-
94-
Assert.notNull(tokenCredential, "TokenCredential must not be null");
9576
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
9677

97-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
98-
.credential(tokenCredential)
99-
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
78+
if (!StringUtils.hasText(connectionProperties.getApiKey())) {
79+
// Entra ID configuration, as the API key is not set
80+
clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
81+
.credential(new DefaultAzureCredentialBuilder().build())
82+
.clientOptions(clientOptions);
83+
}
84+
else {
85+
// Azure OpenAI configuration using API key and endpoint
86+
clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
87+
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
88+
.clientOptions(clientOptions);
89+
}
10090
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
10191
return clientBuilder;
10292
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.azure.openai.autoconfigure;
18+
19+
import java.lang.reflect.Field;
20+
import java.net.URI;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.stream.Collectors;
25+
26+
import com.azure.ai.openai.OpenAIClient;
27+
import com.azure.ai.openai.OpenAIClientBuilder;
28+
import com.azure.ai.openai.implementation.OpenAIClientImpl;
29+
import com.azure.core.http.HttpHeader;
30+
import com.azure.core.http.HttpHeaderName;
31+
import com.azure.core.http.HttpMethod;
32+
import com.azure.core.http.HttpPipeline;
33+
import com.azure.core.http.HttpRequest;
34+
import com.azure.core.http.HttpResponse;
35+
import org.junit.jupiter.api.Disabled;
36+
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable;
38+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
39+
import reactor.core.publisher.Flux;
40+
41+
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
42+
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
43+
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
44+
import org.springframework.ai.chat.messages.AssistantMessage;
45+
import org.springframework.ai.chat.messages.Message;
46+
import org.springframework.ai.chat.messages.UserMessage;
47+
import org.springframework.ai.chat.model.ChatResponse;
48+
import org.springframework.ai.chat.model.Generation;
49+
import org.springframework.ai.chat.prompt.Prompt;
50+
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
51+
import org.springframework.ai.embedding.EmbeddingResponse;
52+
import org.springframework.boot.autoconfigure.AutoConfigurations;
53+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
54+
import org.springframework.core.io.ClassPathResource;
55+
import org.springframework.core.io.Resource;
56+
import org.springframework.util.ReflectionUtils;
57+
58+
import static org.assertj.core.api.Assertions.assertThat;
59+
60+
/**
61+
* @author Christian Tzolov
62+
* @author Piotr Olaszewski
63+
* @author Soby Chacko
64+
* @author Manuel Andreo Garcia
65+
* @since 0.8.0
66+
*/
67+
@DisabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
68+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
69+
@Disabled("IT test environment does not have Entra configured. This test needs to be run manually.")
70+
class AzureOpenAiAutoConfigurationEntraIT {
71+
72+
private static String CHAT_MODEL_NAME = "gpt-4o";
73+
74+
private static String EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
75+
76+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues(
77+
// @formatter:off
78+
"spring.ai.azure.openai.endpoint=" + System.getenv("AZURE_OPENAI_ENDPOINT"),
79+
80+
"spring.ai.azure.openai.chat.options.deployment-name=" + CHAT_MODEL_NAME,
81+
"spring.ai.azure.openai.chat.options.temperature=0.8",
82+
"spring.ai.azure.openai.chat.options.maxTokens=123",
83+
84+
"spring.ai.azure.openai.embedding.options.deployment-name=" + EMBEDDING_MODEL_NAME,
85+
"spring.ai.azure.openai.audio.transcription.options.deployment-name=" + System.getenv("AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME")
86+
// @formatter:on
87+
);
88+
89+
private final Message systemMessage = new SystemPromptTemplate("""
90+
You are a helpful AI assistant. Your name is {name}.
91+
You are an AI assistant that helps people find information.
92+
Your name is {name}
93+
You should reply to the user's request with your name and also in the style of a {voice}.
94+
""").createMessage(Map.of("name", "Bob", "voice", "pirate"));
95+
96+
private final UserMessage userMessage = new UserMessage(
97+
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
98+
99+
@Test
100+
void chatCompletion() {
101+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
102+
.run(context -> {
103+
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
104+
ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage)));
105+
assertThat(response.getResult().getOutput().getText()).contains("Blackbeard");
106+
});
107+
}
108+
109+
@Test
110+
void httpRequestContainsUserAgentAndCustomHeaders() {
111+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
112+
.withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar",
113+
"spring.ai.azure.openai.custom-headers.fizz=buzz")
114+
.run(context -> {
115+
OpenAIClientBuilder openAIClientBuilder = context.getBean(OpenAIClientBuilder.class);
116+
OpenAIClient openAIClient = openAIClientBuilder.buildClient();
117+
Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient");
118+
assertThat(serviceClientField).isNotNull();
119+
ReflectionUtils.makeAccessible(serviceClientField);
120+
OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient);
121+
assertThat(oaci).isNotNull();
122+
HttpPipeline httpPipeline = oaci.getHttpPipeline();
123+
HttpResponse httpResponse = httpPipeline
124+
.send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL()))
125+
.block();
126+
assertThat(httpResponse).isNotNull();
127+
HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT);
128+
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue();
129+
HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo");
130+
assertThat(customHeader1.getValue()).isEqualTo("bar");
131+
HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz");
132+
assertThat(customHeader2.getValue()).isEqualTo("buzz");
133+
});
134+
}
135+
136+
@Test
137+
void chatCompletionStreaming() {
138+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
139+
.run(context -> {
140+
141+
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
142+
143+
Flux<ChatResponse> response = chatModel
144+
.stream(new Prompt(List.of(this.userMessage, this.systemMessage)));
145+
146+
List<ChatResponse> responses = response.collectList().block();
147+
assertThat(responses.size()).isGreaterThan(10);
148+
149+
String stitchedResponseContent = responses.stream()
150+
.map(ChatResponse::getResults)
151+
.flatMap(List::stream)
152+
.map(Generation::getOutput)
153+
.map(AssistantMessage::getText)
154+
.collect(Collectors.joining());
155+
156+
assertThat(stitchedResponseContent).contains("Blackbeard");
157+
});
158+
}
159+
160+
@Test
161+
void embedding() {
162+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
163+
.run(context -> {
164+
AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class);
165+
166+
EmbeddingResponse embeddingResponse = embeddingModel
167+
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
168+
assertThat(embeddingResponse.getResults()).hasSize(2);
169+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
170+
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
171+
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
172+
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
173+
174+
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
175+
});
176+
177+
}
178+
179+
@Test
180+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME", matches = ".+")
181+
void transcribe() {
182+
this.contextRunner
183+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
184+
.run(context -> {
185+
AzureOpenAiAudioTranscriptionModel transcriptionModel = context
186+
.getBean(AzureOpenAiAudioTranscriptionModel.class);
187+
Resource audioFile = new ClassPathResource("/speech/jfk.flac");
188+
String response = transcriptionModel.call(audioFile);
189+
assertThat(response).isEqualTo(
190+
"And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.");
191+
});
192+
}
193+
194+
@Test
195+
void chatActivation() {
196+
197+
// Disable the chat auto-configuration.
198+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
199+
.withPropertyValues("spring.ai.model.chat=none")
200+
.run(context -> {
201+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isEmpty();
202+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty();
203+
});
204+
205+
// The chat auto-configuration is enabled by default.
206+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
207+
.run(context -> {
208+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty();
209+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty();
210+
});
211+
212+
// Explicitly enable the chat auto-configuration.
213+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
214+
.withPropertyValues("spring.ai.model.chat=azure-openai")
215+
.run(context -> {
216+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty();
217+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty();
218+
});
219+
}
220+
221+
@Test
222+
void embeddingActivation() {
223+
224+
// Disable the embedding auto-configuration.
225+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
226+
.withPropertyValues("spring.ai.model.embedding=none")
227+
.run(context -> {
228+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty();
229+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isEmpty();
230+
});
231+
232+
// The embedding auto-configuration is enabled by default.
233+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
234+
.run(context -> {
235+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty();
236+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty();
237+
});
238+
239+
// Explicitly enable the embedding auto-configuration.
240+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
241+
.withPropertyValues("spring.ai.model.embedding=azure-openai")
242+
.run(context -> {
243+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty();
244+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty();
245+
});
246+
}
247+
248+
@Test
249+
void audioTranscriptionActivation() {
250+
251+
// Disable the transcription auto-configuration.
252+
this.contextRunner
253+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
254+
.withPropertyValues("spring.ai.model.audio.transcription=none")
255+
.run(context -> {
256+
assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty();
257+
assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionProperties.class)).isEmpty();
258+
});
259+
260+
// The transcription auto-configuration is enabled by default.
261+
this.contextRunner
262+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
263+
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
264+
265+
// Explicitly enable the transcription auto-configuration.
266+
this.contextRunner
267+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
268+
.withPropertyValues("spring.ai.model.audio.transcription=azure-openai")
269+
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
270+
}
271+
272+
@Test
273+
void openAIClientBuilderCustomizer() {
274+
AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false);
275+
AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false);
276+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
277+
.withBean("first", AzureOpenAIClientBuilderCustomizer.class,
278+
() -> clientBuilder -> firstCustomizationApplied.set(true))
279+
.withBean("second", AzureOpenAIClientBuilderCustomizer.class,
280+
() -> clientBuilder -> secondCustomizationApplied.set(true))
281+
.run(context -> {
282+
context.getBean(OpenAIClientBuilder.class);
283+
assertThat(firstCustomizationApplied.get()).isTrue();
284+
assertThat(secondCustomizationApplied.get()).isTrue();
285+
});
286+
}
287+
288+
}

auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-azure-cosmos-db/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
<dependency>
4949
<groupId>com.azure</groupId>
5050
<artifactId>azure-identity</artifactId>
51-
<version>1.15.4</version> <!-- or the latest version -->
51+
<version>${azure-identity.version}</version>
5252
</dependency>
5353
<dependency>
5454
<groupId>org.springframework.boot</groupId>

0 commit comments

Comments
 (0)