/*
 * Decompiled with CFR 0.152.
 */
package com.geoway.ns.ai.base.tool.http;

import cn.hutool.core.thread.ThreadUtil;
import com.alibaba.fastjson.JSON;
import com.geoway.ns.ai.base.AiChatClientUtil;
import com.geoway.ns.ai.base.chat.client.AiChatClient;
import com.geoway.ns.ai.base.chat.dto.AiChatDTO;
import com.geoway.ns.ai.base.chat.message.AIMessageType;
import com.geoway.ns.ai.base.chat.message.AiHttpResultMessage;
import com.geoway.ns.ai.base.chat.message.AiMessage;
import com.geoway.ns.ai.base.tool.AiTool;
import com.geoway.ns.ai.base.tool.AiToolContext;
import com.geoway.ns.ai.base.tool.AiToolResult;
import com.geoway.ns.ai.base.tool.http.AiHttpToolCallRequest;
import com.geoway.ns.ai.base.tool.http.AiHttpToolCallResult;
import com.geoway.ns.ai.base.tool.http.AiHttpToolDefinitionResponse;
import com.geoway.ns.ai.base.tool.http.AiHttpToolExcuteService;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.function.Consumer;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

/**
 * Executes {@link AiTool} implementations on behalf of a caller that identifies
 * the tool by its fully-qualified class name.
 *
 * <p>Each request instantiates a fresh tool through its no-argument constructor.
 * Three operations are offered: fetching the tool definition, a blocking call,
 * and a streaming call that emits intermediate {@link AiMessage}s as server-sent
 * events followed by a final result message.
 *
 * <p>NOTE(review): the class name comes from the request. This implementation
 * only refuses classes that are not {@link AiTool}s; if requests can originate
 * from untrusted parties, consider restricting names to an explicit allow-list.
 */
@Component
public class AiHttpToolExcuteServiceImpl
implements AiHttpToolExcuteService {
    /** Key under which the chat DTO is looked up in the tool context. */
    private static final String CONTEXT_KEY_CHAT_DTO = "TOOL_CALL_CHAT_DTO";
    /** Key under which the chat session id is looked up in the tool context. */
    private static final String CONTEXT_KEY_SESSION_ID = "TOOL_CALL_CHAT_SESSIONID";
    /** Message template used when a tool cannot be instantiated for its definition. */
    private static final String INIT_ERROR_TEMPLATE = "\u8fdc\u7a0bAI\u5de5\u5177\u3010%s\u3011\u521d\u59cb\u5316\u6709\u8bef\uff0c\u8bf7\u6838\u67e5";
    /** Message template used when executing a tool fails. */
    private static final String CALL_ERROR_TEMPLATE = "\u8fdc\u7a0bAI\u5de5\u5177\u3010%s\u3011\u6267\u884c\u6709\u8bef\uff0c\u8bf7\u6838\u67e5";

    /**
     * Instantiates the named tool and returns its definition.
     *
     * @param classname fully-qualified name of an {@link AiTool} implementation
     * @return the definition reported by the tool, wrapped for the HTTP response
     * @throws RuntimeException if the class cannot be loaded, is not an
     *         {@link AiTool}, cannot be instantiated, or fails while producing
     *         its definition; the original failure is preserved as the cause
     */
    @Override
    public AiHttpToolDefinitionResponse getDefinition(String classname) {
        try {
            AiTool aiTool = AiHttpToolExcuteServiceImpl.instantiateTool(classname);
            return new AiHttpToolDefinitionResponse(aiTool.getToolDefinition());
        }
        catch (Exception e) {
            throw new RuntimeException(String.format(INIT_ERROR_TEMPLATE, classname), e);
        }
    }

    /**
     * Runs the requested tool synchronously, without a message consumer.
     *
     * @param toolCallRequest class name, input and context of the tool call
     * @return the model result plus the JSON-serialised parameters and call result
     * @throws RuntimeException if the tool cannot be created or its execution
     *         fails; the original failure is preserved as the cause
     */
    @Override
    public AiHttpToolCallResult call(AiHttpToolCallRequest toolCallRequest) {
        try {
            AiToolResult aiToolResult = this.invokeTool(toolCallRequest, null);
            return AiHttpToolExcuteServiceImpl.toCallResult(aiToolResult);
        }
        catch (Exception e) {
            throw new RuntimeException(String.format(CALL_ERROR_TEMPLATE, toolCallRequest.getClassName()), e);
        }
    }

    /**
     * Runs the requested tool, forwarding the messages it produces to
     * {@code messageConsumer}.
     *
     * @param aiHttpToolCallRequest class name, input and context of the tool call
     * @param messageConsumer receives the messages the tool emits while running
     * @return the result returned by the tool
     * @throws RuntimeException wrapping any exception raised while creating or
     *         running the tool
     */
    protected AiToolResult chat(AiHttpToolCallRequest aiHttpToolCallRequest, Consumer<AiMessage> messageConsumer) {
        try {
            return this.invokeTool(aiHttpToolCallRequest, messageConsumer);
        }
        catch (Exception exception) {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Runs the requested tool on a background thread and streams its output.
     *
     * <p>Every message produced by the tool is tagged with the session id found
     * in the tool context and emitted as a server-sent event. When the tool
     * finishes, one final {@link AiHttpResultMessage} carrying the call result is
     * emitted. If anything fails, a single message with state {@code -1} and the
     * failure text is emitted instead. The stream is completed in every case.
     *
     * @param aiHttpToolCallRequest class name, input and context of the tool call
     * @return a stream of server-sent events wrapping the emitted messages
     */
    @Override
    public Flux<ServerSentEvent<AiMessage>> callStream(AiHttpToolCallRequest aiHttpToolCallRequest) {
        try {
            // Explicit type witnesses: without them the chain is inferred as
            // Flux<ServerSentEvent<Object>> and does not match the return type.
            return Flux.<AiMessage>create(sink -> ThreadUtil.execute(() -> {
                String sessionId = null;
                try {
                    // Context parsing happens inside the try so that a failure here
                    // is still reported and the sink is still completed.
                    AiToolContext aiToolContext = new AiToolContext(aiHttpToolCallRequest.getToolContext());
                    sessionId = aiToolContext.get(CONTEXT_KEY_SESSION_ID, String.class);
                    final String currentSessionId = sessionId;
                    AiToolResult aiToolResult = this.chat(aiHttpToolCallRequest, aiMessage -> {
                        aiMessage.setSessionId(currentSessionId);
                        sink.next(aiMessage);
                    });
                    AiHttpToolCallResult aiHttpToolCallResult = AiHttpToolExcuteServiceImpl.toCallResult(aiToolResult);
                    aiHttpToolCallResult.setToolContext(aiToolContext.getContext());
                    aiHttpToolCallResult.setIsReturn(aiToolResult.getReturn());
                    AiHttpResultMessage aiHttpResultMessage = new AiHttpResultMessage(aiHttpToolCallResult);
                    aiHttpResultMessage.setSessionId(currentSessionId);
                    sink.next(aiHttpResultMessage);
                }
                catch (Exception exception) {
                    exception.printStackTrace();
                    // getMessage() may be null; fall back to toString() so the
                    // client never receives an empty error.
                    String detail = exception.getMessage() != null ? exception.getMessage() : exception.toString();
                    AiMessage errorMessage = new AiMessage();
                    errorMessage.setSessionId(sessionId);
                    errorMessage.setContent(detail);
                    errorMessage.setRole(AIMessageType.Assistant.getDesc());
                    errorMessage.setState(-1);
                    sink.next(errorMessage);
                }
                finally {
                    sink.complete();
                }
            })).map(data -> ServerSentEvent.<AiMessage>builder().data(data).build()).onErrorResume(error -> {
                error.printStackTrace();
                throw new RuntimeException(error);
            });
        }
        catch (Exception e) {
            throw new RuntimeException(String.format(CALL_ERROR_TEMPLATE, aiHttpToolCallRequest.getClassName()), e);
        }
    }

    /**
     * Creates the tool named in the request, prepares its context and invokes it.
     * A chat client is attached to the context only when the context carries a
     * chat DTO.
     */
    private AiToolResult invokeTool(AiHttpToolCallRequest request, Consumer<AiMessage> messageConsumer) throws Exception {
        AiTool aiTool = AiHttpToolExcuteServiceImpl.instantiateTool(request.getClassName());
        AiToolContext aiToolContext = new AiToolContext(request.getToolContext());
        AiChatDTO aiChatDTO = aiToolContext.get(CONTEXT_KEY_CHAT_DTO, AiChatDTO.class);
        if (aiChatDTO != null) {
            AiChatClient aiChatClient = AiChatClientUtil.getChatClient(aiChatDTO);
            aiToolContext.setChatClient(aiChatClient);
        }
        return aiTool.call(request.getToolInput(), aiToolContext, messageConsumer);
    }

    /**
     * Loads {@code className} and instantiates it through its no-argument
     * constructor, after verifying that it is an {@link AiTool}.
     */
    private static AiTool instantiateTool(String className) throws Exception {
        // Load without initialising and check the type before constructing, so a
        // class that is not an AiTool never has its static initialiser or
        // constructor executed.
        Class<?> loaded = Class.forName(className, false, AiHttpToolExcuteServiceImpl.class.getClassLoader());
        Class<? extends AiTool> toolClass = loaded.asSubclass(AiTool.class);
        try {
            return toolClass.getDeclaredConstructor().newInstance();
        }
        catch (InvocationTargetException e) {
            // Surface the constructor's own failure rather than the reflective wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception)cause;
            }
            if (cause instanceof Error) {
                throw (Error)cause;
            }
            throw e;
        }
    }

    /** Copies the model result and the JSON-serialised parameters and call result into a response object. */
    private static AiHttpToolCallResult toCallResult(AiToolResult aiToolResult) {
        AiHttpToolCallResult aiHttpToolCallResult = new AiHttpToolCallResult();
        aiHttpToolCallResult.setModelResult(aiToolResult.getModelResult());
        aiHttpToolCallResult.setParam(JSON.toJSONString((Object)aiToolResult.getToolParam()));
        aiHttpToolCallResult.setToolCallResult(JSON.toJSONString((Object)aiToolResult.getCallResult()));
        return aiHttpToolCallResult;
    }
}

