package com.monkeylessey.service.impl;
|
|
import cn.hutool.http.HttpUtil;
import com.monkeylessey.config.AIConfig;
import com.monkeylessey.domain.form.ChatForm;
import com.monkeylessey.service.ChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
|
|
/**
|
* @author:xp
|
* @date:2025/4/18 14:22
|
*/
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class ChatServiceImpl implements ChatService {
|
|
private final AIConfig aiConfig;
|
|
@Override
|
public SseEmitter sendMsg(ChatForm form) {
|
|
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); // 设置无限超时
|
|
// 1. 构建请求体
|
Map<String, Object> body = new HashMap<>();
|
body.put("query", form.getQuery());
|
body.put("mode", form.getMode());
|
body.put("kb_name", form.getKbName());
|
body.put("top_k", form.getTopK());
|
body.put("score_threshold", form.getScoreThreshold());
|
body.put("history", form.getHistory());
|
body.put("stream", form.getStream());
|
body.put("model", form.getModel());
|
body.put("temperature", form.getTemperature());
|
body.put("max_tokens", form.getMaxTokens());
|
body.put("prompt_name", form.getPromptName());
|
body.put("return_direct", form.getReturnDirect());
|
|
// 2. 异步处理SSE转发
|
CompletableFuture.runAsync(() -> {
|
try {
|
WebClient client = WebClient.create(aiConfig.getFullDomain());
|
client.post()
|
.uri("/chat/kb_chat")
|
.contentType(MediaType.APPLICATION_JSON)
|
.accept(MediaType.TEXT_EVENT_STREAM) // 声明接受SSE响应
|
.bodyValue(body)
|
.retrieve()
|
.bodyToFlux(String.class) // 使用Flux接收流式响应
|
.subscribe(
|
data -> {
|
try {
|
emitter.send(SseEmitter.event().data(data));
|
} catch (IOException e) {
|
log.error("发送失败", e.getMessage());
|
}
|
},
|
error -> emitter.completeWithError(error),
|
() -> emitter.complete()
|
);
|
} catch (Exception e) {
|
emitter.completeWithError(e);
|
}
|
});
|
|
// 3. 连接生命周期回调
|
emitter.onCompletion(() -> log.info("SSE connection completed"));
|
emitter.onTimeout(() -> log.warn("SSE connection timed out"));
|
emitter.onError(ex -> log.error("SSE error: {}", ex.getMessage()));
|
return emitter;
|
}
|
}
|