package org.springframework.ai.anthropic;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:org/springframework/ai/anthropic/AnthropicChatModel.class */
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
    private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
    public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();
    public static final Integer DEFAULT_MAX_TOKENS = 500;
    public static final Float DEFAULT_TEMPERATURE = Float.valueOf(0.8f);
    public final AnthropicApi anthropicApi;
    private final AnthropicChatOptions defaultOptions;
    public final RetryTemplate retryTemplate;
    private final ObservationRegistry observationRegistry;
    private ChatModelObservationConvention observationConvention;

    public AnthropicChatModel(AnthropicApi anthropicApi) {
        this(anthropicApi, AnthropicChatOptions.builder().withModel(DEFAULT_MODEL_NAME).withMaxTokens(DEFAULT_MAX_TOKENS).withTemperature(DEFAULT_TEMPERATURE).build());
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions) {
        this(anthropicApi, anthropicChatOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, RetryTemplate retryTemplate) {
        this(anthropicApi, anthropicChatOptions, retryTemplate, null);
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext) {
        this(anthropicApi, anthropicChatOptions, retryTemplate, functionCallbackContext, List.of());
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext, List<FunctionCallback> list) {
        this(anthropicApi, anthropicChatOptions, retryTemplate, functionCallbackContext, list, ObservationRegistry.NOOP);
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext, List<FunctionCallback> list, ObservationRegistry observationRegistry) {
        super(functionCallbackContext, anthropicChatOptions, list);
        this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
        Assert.notNull(anthropicApi, "AnthropicApi must not be null");
        Assert.notNull(anthropicChatOptions, "DefaultOptions must not be null");
        Assert.notNull(retryTemplate, "RetryTemplate must not be null");
        Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
        this.anthropicApi = anthropicApi;
        this.defaultOptions = anthropicChatOptions;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    public ChatResponse call(Prompt prompt) {
        AnthropicApi.ChatCompletionRequest createRequest = createRequest(prompt, false);
        ChatModelObservationContext build = ChatModelObservationContext.builder().prompt(prompt).provider(AnthropicApi.PROVIDER_NAME).requestOptions(buildRequestOptions(createRequest)).build();
        ChatResponse chatResponse = (ChatResponse) ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
            return build;
        }, this.observationRegistry).observe(() -> {
            ChatResponse chatResponse2 = toChatResponse((AnthropicApi.ChatCompletionResponse) ((ResponseEntity) this.retryTemplate.execute(retryContext -> {
                return this.anthropicApi.chatCompletionEntity(createRequest);
            })).getBody());
            build.setResponse(chatResponse2);
            return chatResponse2;
        });
        return (chatResponse == null || !isToolCall(chatResponse, Set.of("tool_use"))) ? chatResponse : call(new Prompt(handleToolCalls(prompt, chatResponse), prompt.getOptions()));
    }

    public Flux<ChatResponse> stream(Prompt prompt) {
        return Flux.deferContextual(contextView -> {
            AnthropicApi.ChatCompletionRequest createRequest = createRequest(prompt, true);
            ChatModelObservationContext build = ChatModelObservationContext.builder().prompt(prompt).provider(AnthropicApi.PROVIDER_NAME).requestOptions(buildRequestOptions(createRequest)).build();
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
                return build;
            }, this.observationRegistry);
            observation.parentObservation((Observation) contextView.getOrDefault("micrometer.observation", (Object) null)).start();
            Flux switchMap = this.anthropicApi.chatCompletionStream(createRequest).switchMap(chatCompletionResponse -> {
                ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
                return isToolCall(chatResponse, Set.of("tool_use")) ? stream(new Prompt(handleToolCalls(prompt, chatResponse), prompt.getOptions())) : Mono.just(chatResponse);
            });
            Objects.requireNonNull(observation);
            Flux contextWrite = switchMap.doOnError(observation::error).doFinally(signalType -> {
                observation.stop();
            }).contextWrite(context -> {
                return context.put("micrometer.observation", observation);
            });
            MessageAggregator messageAggregator = new MessageAggregator();
            Objects.requireNonNull(build);
            return messageAggregator.aggregate(contextWrite, (v1) -> {
                r2.setResponse(v1);
            });
        });
    }

    private ChatResponse toChatResponse(AnthropicApi.ChatCompletionResponse chatCompletionResponse) {
        if (chatCompletionResponse == null) {
            logger.warn("Null chat completion returned");
            return new ChatResponse(List.of());
        }
        ArrayList arrayList = new ArrayList(chatCompletionResponse.content().stream().filter(contentBlock -> {
            return contentBlock.type() != AnthropicApi.ContentBlock.Type.TOOL_USE;
        }).map(contentBlock2 -> {
            return new Generation(new AssistantMessage(contentBlock2.text(), Map.of()), ChatGenerationMetadata.from(chatCompletionResponse.stopReason(), (Object) null));
        }).toList());
        List<AnthropicApi.ContentBlock> list = chatCompletionResponse.content().stream().filter(contentBlock3 -> {
            return contentBlock3.type() == AnthropicApi.ContentBlock.Type.TOOL_USE;
        }).toList();
        if (!CollectionUtils.isEmpty(list)) {
            ArrayList arrayList2 = new ArrayList();
            for (AnthropicApi.ContentBlock contentBlock4 : list) {
                arrayList2.add(new AssistantMessage.ToolCall(contentBlock4.id(), "function", contentBlock4.name(), ModelOptionsUtils.toJsonString(contentBlock4.input())));
            }
            arrayList.add(new Generation(new AssistantMessage("", Map.of(), arrayList2), ChatGenerationMetadata.from(chatCompletionResponse.stopReason(), (Object) null)));
        }
        return new ChatResponse(arrayList, from(chatCompletionResponse));
    }

    private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse chatCompletionResponse) {
        Assert.notNull(chatCompletionResponse, "Anthropic ChatCompletionResult must not be null");
        return ChatResponseMetadata.builder().withId(chatCompletionResponse.id()).withModel(chatCompletionResponse.model()).withUsage(AnthropicUsage.from(chatCompletionResponse.usage())).withKeyValue("stop-reason", chatCompletionResponse.stopReason()).withKeyValue("stop-sequence", chatCompletionResponse.stopSequence()).withKeyValue("type", chatCompletionResponse.type()).build();
    }

    private String fromMediaData(Object obj) {
        if (obj instanceof byte[]) {
            return Base64.getEncoder().encodeToString((byte[]) obj);
        }
        if (obj instanceof String) {
            return (String) obj;
        }
        throw new IllegalArgumentException("Unsupported media data type: " + obj.getClass().getSimpleName());
    }

    AnthropicApi.ChatCompletionRequest createRequest(Prompt prompt, boolean z) {
        HashSet hashSet = new HashSet();
        AnthropicApi.ChatCompletionRequest chatCompletionRequest = new AnthropicApi.ChatCompletionRequest(this.defaultOptions.getModel(), prompt.getInstructions().stream().filter(message -> {
            return message.getMessageType() != MessageType.SYSTEM;
        }).map(message2 -> {
            if (message2.getMessageType() == MessageType.USER) {
                ArrayList arrayList = new ArrayList(List.of(new AnthropicApi.ContentBlock(message2.getContent())));
                if (message2 instanceof UserMessage) {
                    UserMessage userMessage = (UserMessage) message2;
                    if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                        arrayList.addAll(userMessage.getMedia().stream().map(media -> {
                            return new AnthropicApi.ContentBlock(media.getMimeType().toString(), fromMediaData(media.getData()));
                        }).toList());
                    }
                }
                return new AnthropicApi.AnthropicMessage(arrayList, AnthropicApi.Role.valueOf(message2.getMessageType().name()));
            }
            if (message2.getMessageType() != MessageType.ASSISTANT) {
                if (message2.getMessageType() == MessageType.TOOL) {
                    return new AnthropicApi.AnthropicMessage(((ToolResponseMessage) message2).getResponses().stream().map(toolResponse -> {
                        return new AnthropicApi.ContentBlock(AnthropicApi.ContentBlock.Type.TOOL_RESULT, toolResponse.id(), toolResponse.responseData());
                    }).toList(), AnthropicApi.Role.USER);
                }
                throw new IllegalArgumentException("Unsupported message type: " + message2.getMessageType());
            }
            AssistantMessage assistantMessage = (AssistantMessage) message2;
            ArrayList arrayList2 = new ArrayList();
            if (StringUtils.hasText(message2.getContent())) {
                arrayList2.add(new AnthropicApi.ContentBlock(message2.getContent()));
            }
            if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                    arrayList2.add(new AnthropicApi.ContentBlock(AnthropicApi.ContentBlock.Type.TOOL_USE, toolCall.id(), toolCall.name(), (Map<String, Object>) ModelOptionsUtils.jsonToMap(toolCall.arguments())));
                }
            }
            return new AnthropicApi.AnthropicMessage(arrayList2, AnthropicApi.Role.ASSISTANT);
        }).toList(), (String) prompt.getInstructions().stream().filter(message3 -> {
            return message3.getMessageType() == MessageType.SYSTEM;
        }).map(message4 -> {
            return message4.getContent();
        }).collect(Collectors.joining(System.lineSeparator())), this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), Boolean.valueOf(z));
        if (prompt.getOptions() != null) {
            AnthropicChatOptions anthropicChatOptions = (AnthropicChatOptions) ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AnthropicChatOptions.class);
            hashSet.addAll(runtimeFunctionCallbackConfigurations(anthropicChatOptions));
            chatCompletionRequest = (AnthropicApi.ChatCompletionRequest) ModelOptionsUtils.merge(anthropicChatOptions, chatCompletionRequest, AnthropicApi.ChatCompletionRequest.class);
        }
        if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) {
            hashSet.addAll(this.defaultOptions.getFunctions());
        }
        AnthropicApi.ChatCompletionRequest chatCompletionRequest2 = (AnthropicApi.ChatCompletionRequest) ModelOptionsUtils.merge(chatCompletionRequest, this.defaultOptions, AnthropicApi.ChatCompletionRequest.class);
        if (!CollectionUtils.isEmpty(hashSet)) {
            chatCompletionRequest2 = AnthropicApi.ChatCompletionRequest.from(chatCompletionRequest2).withTools(getFunctionTools(hashSet)).build();
        }
        return chatCompletionRequest2;
    }

    private List<AnthropicApi.Tool> getFunctionTools(Set<String> set) {
        return resolveFunctionCallbacks(set).stream().map(functionCallback -> {
            return new AnthropicApi.Tool(functionCallback.getName(), functionCallback.getDescription(), ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
        }).toList();
    }

    private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest chatCompletionRequest) {
        return ChatOptionsBuilder.builder().withModel(chatCompletionRequest.model()).withMaxTokens(chatCompletionRequest.maxTokens()).withStopSequences(chatCompletionRequest.stopSequences()).withTemperature(chatCompletionRequest.temperature()).withTopK(chatCompletionRequest.topK()).withTopP(chatCompletionRequest.topP()).build();
    }

    public ChatOptions getDefaultOptions() {
        return AnthropicChatOptions.fromOptions(this.defaultOptions);
    }

    public void setObservationConvention(ChatModelObservationConvention chatModelObservationConvention) {
        Assert.notNull(chatModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = chatModelObservationConvention;
    }
}
