package org.springframework.ai.openai;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

/**
 * Chat client backed by {@link OpenAiApi} that supports both blocking ({@link #call(Prompt)})
 * and streaming ({@link #stream(Prompt)}) chat completions.
 *
 * <p>Requests are built by {@link #createRequest(Prompt, boolean)}, which combines the prompt's
 * messages with the prompt-level options, the client's default options and the function tools
 * enabled for the request. Tool/function call round-trips are delegated to the
 * {@link AbstractFunctionCallSupport} base class through the {@code do*} hook methods below.
 */
public class OpenAiChatClient extends AbstractFunctionCallSupport<OpenAiApi.ChatCompletionMessage, OpenAiApi.ChatCompletionRequest, ResponseEntity<OpenAiApi.ChatCompletion>> implements ChatClient, StreamingChatClient {

    private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClient.class);

    /** Options merged into every request; guaranteed non-null by the constructor. */
    private final OpenAiChatOptions defaultOptions;

    /** Retry template that wraps the API invocations made by this client. */
    public final RetryTemplate retryTemplate;

    /** Low-level OpenAI API used to execute the chat completion requests. */
    private final OpenAiApi openAiApi;

    /**
     * Creates a client that uses {@link OpenAiApi#DEFAULT_CHAT_MODEL} with a temperature of 0.7.
     *
     * @param openAiApi the API used to talk to OpenAI; must not be null
     */
    public OpenAiChatClient(OpenAiApi openAiApi) {
        this(openAiApi, OpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
    }

    /**
     * Creates a client with the given default options, no function callback context and
     * {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
     *
     * @param openAiApi the API used to talk to OpenAI; must not be null
     * @param openAiChatOptions the default options applied to every request; must not be null
     */
    public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions openAiChatOptions) {
        this(openAiApi, openAiChatOptions, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a fully configured client.
     *
     * @param openAiApi the API used to talk to OpenAI; must not be null
     * @param openAiChatOptions the default options applied to every request; must not be null
     * @param functionCallbackContext context handed to the function-call support base class; may be null
     * @param retryTemplate the retry template wrapping API invocations; must not be null
     * @throws IllegalArgumentException if a required argument is null
     */
    public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions openAiChatOptions, FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
        super(functionCallbackContext);
        Assert.notNull(openAiApi, "OpenAiApi must not be null");
        Assert.notNull(openAiChatOptions, "Options must not be null");
        Assert.notNull(retryTemplate, "RetryTemplate must not be null");
        this.openAiApi = openAiApi;
        this.defaultOptions = openAiChatOptions;
        this.retryTemplate = retryTemplate;
    }

    /**
     * Executes a blocking chat completion for the prompt, including any function-call
     * round-trips handled by the base class.
     *
     * @param prompt the prompt to send
     * @return one {@link Generation} per returned choice plus response metadata and rate-limit
     *         information, or an empty response when the API returns no body
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
        return this.retryTemplate.execute(retryContext -> {
            ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = callWithFunctionSupport(request);
            OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
            if (chatCompletion == null) {
                logger.warn("No chat completion returned for prompt: {}", prompt);
                return new ChatResponse(List.of());
            }
            List<Generation> generations = chatCompletion.choices()
                .stream()
                .map(choice -> toGeneration(chatCompletion.id(), choice))
                .toList();
            return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion)
                .withRateLimit(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity)));
        });
    }

    /**
     * Converts one choice of a blocking completion into a {@link Generation}.
     * Generation metadata is attached only when the choice carries a finish reason, which
     * mirrors the handling in {@link #stream(Prompt)} and avoids dereferencing a null reason.
     */
    private Generation toGeneration(String completionId, OpenAiApi.ChatCompletion.Choice choice) {
        Generation generation = new Generation(choice.message().content(), toMap(completionId, choice));
        if (choice.finishReason() != null) {
            generation = generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
        }
        return generation;
    }

    /**
     * Builds the generation properties ({@code role}, {@code finishReason}, {@code id}) for a choice.
     * {@code role} and {@code finishReason} are only included when present on the choice.
     */
    private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
        Map<String, Object> properties = new HashMap<>();
        OpenAiApi.ChatCompletionMessage message = choice.message();
        if (message.role() != null) {
            properties.put("role", message.role().name());
        }
        if (choice.finishReason() != null) {
            properties.put("finishReason", choice.finishReason().name());
        }
        properties.put("id", id);
        return properties;
    }

    /**
     * Executes a streaming chat completion for the prompt.
     *
     * <p>Each chunk is converted to a {@link OpenAiApi.ChatCompletion} so that it can be passed
     * through the same function-call handling as a blocking response. If processing a chunk
     * fails, the error is logged and an empty {@link ChatResponse} is emitted for that chunk.
     *
     * @param prompt the prompt to send
     * @return a stream of responses, one per processed chunk
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
        return this.retryTemplate.execute(retryContext -> {
            Flux<OpenAiApi.ChatCompletionChunk> chunks = this.openAiApi.chatCompletionStream(request);
            // Remembers the first role seen per completion id so that chunks of the same
            // completion that carry no role of their own can reuse it.
            ConcurrentHashMap<String, String> roleByCompletionId = new ConcurrentHashMap<>();
            return chunks.map(this::chunkToChatCompletion).map(chunkCompletion -> {
                try {
                    OpenAiApi.ChatCompletion chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chunkCompletion))).getBody();
                    String id = chatCompletion.id();
                    List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
                        if (choice.message().role() != null) {
                            roleByCompletionId.putIfAbsent(id, choice.message().role().name());
                        }
                        String finishReason = choice.finishReason() != null ? choice.finishReason().name() : "";
                        // Map.of rejects null values, so fall back to an empty role when none
                        // has been recorded for this completion id yet.
                        String role = roleByCompletionId.getOrDefault(id, "");
                        Map<String, Object> properties = Map.of("id", id, "role", role, "finishReason", finishReason);
                        Generation generation = new Generation(choice.message().content(), properties);
                        if (choice.finishReason() != null) {
                            generation = generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
                        }
                        return generation;
                    }).toList();
                    return new ChatResponse(generations);
                } catch (Exception e) {
                    // Best effort: a single bad chunk must not terminate the whole stream.
                    logger.error("Error processing chat completion", e);
                    return new ChatResponse(List.of());
                }
            });
        });
    }

    /**
     * Converts a streaming chunk into a {@link OpenAiApi.ChatCompletion}, using each chunk
     * choice's delta as the choice message. The object type is fixed to
     * {@code "chat.completion"} and the last constructor argument is left null.
     */
    private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
        List<OpenAiApi.ChatCompletion.Choice> choices = chunk.choices()
            .stream()
            .map(chunkChoice -> new OpenAiApi.ChatCompletion.Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(), chunkChoice.logprobs()))
            .toList();
        return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.systemFingerprint(), "chat.completion", null);
    }

    /**
     * Builds the API request for a prompt.
     *
     * <p>Steps, in order: map the prompt instructions to chat messages; merge the prompt's
     * runtime options (if any) with the request; merge the request with the client's default
     * options; finally, if any functions were enabled by either set of options, merge their
     * tool definitions into the request.
     *
     * @param prompt the prompt providing messages and optional runtime options
     * @param stream whether the request is for a streaming completion
     * @return the request to send to the API
     * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions}
     */
    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        Set<String> enabledFunctions = new HashSet<>();

        List<OpenAiApi.ChatCompletionMessage> messages = prompt.getInstructions()
            .stream()
            .map(message -> new OpenAiApi.ChatCompletionMessage(message.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(message.getMessageType().name())))
            .toList();
        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(messages, stream);

        if (prompt.getOptions() != null) {
            ChatOptions runtimeOptions = prompt.getOptions();
            // Defensive guard retained from the original implementation.
            if (!(runtimeOptions instanceof ChatOptions)) {
                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + prompt.getOptions().getClass().getSimpleName());
            }
            OpenAiChatOptions promptOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, OpenAiChatOptions.class);
            enabledFunctions.addAll(handleFunctionCallbackConfigurations(promptOptions, true));
            request = ModelOptionsUtils.merge(promptOptions, request, OpenAiApi.ChatCompletionRequest.class);
        }

        // defaultOptions is final and asserted non-null in the constructor, so no null check is needed.
        enabledFunctions.addAll(handleFunctionCallbackConfigurations(this.defaultOptions, false));
        request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);

        if (!CollectionUtils.isEmpty(enabledFunctions)) {
            OpenAiChatOptions toolOptions = OpenAiChatOptions.builder().withTools(getFunctionTools(enabledFunctions)).build();
            request = ModelOptionsUtils.merge(toolOptions, request, OpenAiApi.ChatCompletionRequest.class);
        }
        return request;
    }

    /**
     * Resolves the named function callbacks and describes each one as an API function tool
     * (description, name and input type schema).
     */
    private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
        return resolveFunctionCallbacks(functionNames).stream()
            .map(callback -> new OpenAiApi.FunctionTool(new OpenAiApi.FunctionTool.Function(callback.getDescription(), callback.getName(), callback.getInputTypeSchema())))
            .toList();
    }

    /**
     * Invokes the registered callback for every tool call in {@code responseMessage}, appends
     * each result to {@code conversationHistory} as a {@code TOOL} message, and builds the
     * follow-up (non-streaming) request from the extended history merged with the previous request.
     *
     * @param previousRequest the request that produced the tool calls
     * @param responseMessage the assistant message containing the tool calls
     * @param conversationHistory the message history; this list is appended to in place
     * @return the follow-up request carrying the tool responses
     * @throws IllegalStateException if a tool call names a function with no registered callback
     */
    @Override
    protected OpenAiApi.ChatCompletionRequest doCreateToolResponseRequest(OpenAiApi.ChatCompletionRequest previousRequest, OpenAiApi.ChatCompletionMessage responseMessage, List<OpenAiApi.ChatCompletionMessage> conversationHistory) {
        for (OpenAiApi.ChatCompletionMessage.ToolCall toolCall : responseMessage.toolCalls()) {
            String functionName = toolCall.function().name();
            String functionArguments = toolCall.function().arguments();
            if (!this.functionCallbackRegister.containsKey(functionName)) {
                throw new IllegalStateException("No function callback found for function name: " + functionName);
            }
            FunctionCallback callback = this.functionCallbackRegister.get(functionName);
            String functionResponse = callback.call(functionArguments);
            conversationHistory.add(new OpenAiApi.ChatCompletionMessage(functionResponse, OpenAiApi.ChatCompletionMessage.Role.TOOL, functionName, toolCall.id(), null));
        }
        OpenAiApi.ChatCompletionRequest followUp = new OpenAiApi.ChatCompletionRequest(conversationHistory, false);
        return ModelOptionsUtils.merge(followUp, previousRequest, OpenAiApi.ChatCompletionRequest.class);
    }

    /**
     * Returns the messages carried by the given request.
     */
    @Override
    public List<OpenAiApi.ChatCompletionMessage> doGetUserMessages(OpenAiApi.ChatCompletionRequest chatCompletionRequest) {
        return chatCompletionRequest.messages();
    }

    /**
     * Returns the message of the first choice in the response body.
     * The body and its choices are not checked here; presumably this is only invoked after
     * {@link #isToolFunctionCall(ResponseEntity)} returned true - verify against the base class.
     */
    @Override
    public OpenAiApi.ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<OpenAiApi.ChatCompletion> responseEntity) {
        return responseEntity.getBody().choices().iterator().next().message();
    }

    /**
     * Executes the request through {@link OpenAiApi#chatCompletionEntity}.
     */
    @Override
    public ResponseEntity<OpenAiApi.ChatCompletion> doChatCompletion(OpenAiApi.ChatCompletionRequest chatCompletionRequest) {
        return this.openAiApi.chatCompletionEntity(chatCompletionRequest);
    }

    /**
     * Tells whether the response asks for tool execution: the body and its choices must be
     * present, and the first choice must both contain tool calls and have the
     * {@code TOOL_CALLS} finish reason.
     */
    @Override
    public boolean isToolFunctionCall(ResponseEntity<OpenAiApi.ChatCompletion> responseEntity) {
        OpenAiApi.ChatCompletion body = responseEntity.getBody();
        if (body == null) {
            return false;
        }
        List<OpenAiApi.ChatCompletion.Choice> choices = body.choices();
        if (CollectionUtils.isEmpty(choices)) {
            return false;
        }
        OpenAiApi.ChatCompletion.Choice firstChoice = choices.get(0);
        return !CollectionUtils.isEmpty(firstChoice.message().toolCalls())
                && firstChoice.finishReason() == OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS;
    }
}
