本文通过一个完整的 Java 示例,演示如何使用 OkHttp 调用远程的 SSE(Server-Sent Events)流式接口。
背景:我们有一个智能 AI 聊天界面,需要调用第三方厂商的大模型 chat 接口来生成答案。由于大模型理解并检索问题比较耗时,如果等完整结果生成后再一次性返回,客户端就只能同步地长时间等待。因此我们的方案是以流的方式把结果逐段返回给客户端,这样能大大提升用户体验。
1.引入相关依赖
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency> <groupId>org.projectlombok</groupId> <artifactId>lombok</artifactId> <optional>true</optional> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-test</artifactId> <scope>test</scope> </dependency> <dependency> <groupId>com.squareup.okhttp3</groupId> <artifactId>okhttp</artifactId> <version>4.2.0</version> </dependency> <dependency> <groupId>com.squareup.okhttp3</groupId> <artifactId>okhttp-sse</artifactId> <version>4.2.0</version> </dependency> <dependency> <groupId>io.jsonwebtoken</groupId> <artifactId>jjwt</artifactId> <version>0.9.1</version> </dependency> <dependency> <groupId>com.alibaba</groupId> <artifactId>fastjson</artifactId> <version>1.2.78</version> </dependency>
2. controller
package com.demo.controller; import com.alibaba.fastjson.JSON; import com.demo.listener.SSEListener; import com.demo.params.req.ChatGlmDto; import com.demo.utils.ApiTokenUtil; import com.demo.utils.ExecuteSSEUtil; import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; import javax.servlet.http.HttpServletResponse; @RestController @Slf4j public class APITestController { private static final String API_KEY = "xxxx"; private static final String URL = "xxx"; @PostMapping(value = "/sse-invoke", produces = "text/event-stream;charset=UTF-8") public void sse(@RequestBody ChatGlmDto chatGlmDto, HttpServletResponse rp) { try { String token = ApiTokenUtil.generateClientToken(API_KEY); SSEListener sseListener = new SSEListener(chatGlmDto, rp); ExecuteSSEUtil.executeSSE(URL, token, sseListener, JSON.toJSONString(chatGlmDto)); } catch (Exception e) { log.error("请求SSE错误处理", e); } } }
3. 监听器
监听器中各个回调(onOpen、onEvent、onClosed、onFailure)的处理逻辑都可以自行定义,按需实现相关业务;其中 onEvent 用来接收上游推送的每一条消息。
package com.demo.listener; import com.alibaba.fastjson.JSON; import com.demo.params.req.ChatGlmDto; import lombok.Data; import lombok.extern.slf4j.Slf4j; import okhttp3.Response; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; import javax.servlet.http.HttpServletResponse; import java.util.concurrent.CountDownLatch; @Slf4j @Data public class SSEListener extends EventSourceListener { private CountDownLatch countDownLatch = new CountDownLatch(1); private ChatGlmDto chatGlmDto; private HttpServletResponse rp; private StringBuffer output = new StringBuffer(); public SSEListener(ChatGlmDto chatGlmDto, HttpServletResponse response) { this.chatGlmDto = chatGlmDto; this.rp = response; } /** * {@inheritDoc} * 建立sse连接 */ @Override public void onOpen(final EventSource eventSource, final Response response) { if (rp != null) { rp.setContentType("text/event-stream"); rp.setCharacterEncoding("UTF-8"); rp.setStatus(200); log.info("建立sse连接..." + JSON.toJSONString(chatGlmDto)); } else { log.info("客户端非sse推送" + JSON.toJSONString(chatGlmDto)); } } /** * 事件 * * @param eventSource * @param id * @param type * @param data */ @Override public void onEvent(EventSource eventSource, String id, String type, String data) { try { output.append(data); if ("finish".equals(type)) { log.info("请求结束{} {}", chatGlmDto.getMessageId(), output.toString()); } if ("error".equals(type)) { log.info("{}: {}source {}", chatGlmDto.getMessageId(), data, JSON.toJSONString(chatGlmDto)); } if (rp != null) { if ("\n".equals(data)) { rp.getWriter().write("event:" + type + "\n"); rp.getWriter().write("id:" + chatGlmDto.getMessageId() + "\n"); rp.getWriter().write("data:\n\n"); rp.getWriter().flush(); } else { String[] dataArr = data.split("\\n"); for (int i = 0; i < dataArr.length; i++) { if (i == 0) { rp.getWriter().write("event:" + type + "\n"); rp.getWriter().write("id:" + chatGlmDto.getMessageId() + "\n"); } if (i == dataArr.length - 1) { rp.getWriter().write("data:" + dataArr[i] + "\n\n"); 
rp.getWriter().flush(); } else { rp.getWriter().write("data:" + dataArr[i] + "\n"); rp.getWriter().flush(); } } } } } catch (Exception e) { log.error("消息错误[" + JSON.toJSONString(chatGlmDto) + "]", e); countDownLatch.countDown(); throw new RuntimeException(e); } } /** * {@inheritDoc} */ @Override public void onClosed(final EventSource eventSource) { log.info("sse连接关闭:{}", chatGlmDto.getMessageId()); log.info("结果输出:{}" + output.toString()); countDownLatch.countDown(); } /** * {@inheritDoc} */ @Override public void onFailure(final EventSource eventSource, final Throwable t, final Response response) { log.error("使用事件源时出现异常... [响应:{}]...", chatGlmDto.getMessageId()); countDownLatch.countDown(); } public CountDownLatch getCountDownLatch() { return this.countDownLatch; } }
4. 相关工具类
获取 token 的 ApiTokenUtil 类:是否需要取决于你所调用接口的鉴权方式;这里为了让程序能完整跑起来,就保留了。
package com.demo.utils;

import com.alibaba.fastjson.JSON;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * Builds the HS256-signed JWT sent as the {@code Authorization} header to the remote chat API.
 */
public class ApiTokenUtil {

    /**
     * Lifetime added to "now" for the {@code exp} claim: 5 * 600 * 1000 ms = 50 minutes.
     * NOTE(review): possibly intended as 5 minutes (5 * 60 * 1000) — confirm.
     */
    private static final long EXPIRE_MILLIS = 5 * 600 * 1000L;

    /**
     * Generates a signed token from an API key of the form {@code <id>.<secret>}.
     *
     * <p>The payload carries {@code api_key} (the id part), {@code exp} and {@code timestamp}, both
     * in epoch milliseconds. The token is signed with the secret part as an HS256 key.
     *
     * @param apikey API key in the form {@code <id>.<secret>}
     * @return the compact JWT string, never {@code null}
     * @throws IllegalArgumentException if {@code apikey} is {@code null} or has no "." separator
     * @throws IllegalStateException if signing the token fails
     */
    public static String generateClientToken(String apikey) {
        if (apikey == null) {
            throw new IllegalArgumentException("apikey must not be null");
        }
        String[] apiKeyParts = apikey.split("\\.");
        if (apiKeyParts.length < 2) {
            throw new IllegalArgumentException("apikey must have the form <id>.<secret>");
        }
        String apiKeyId = apiKeyParts[0];
        String secret = apiKeyParts[1];

        Map<String, Object> header = new HashMap<>();
        header.put("alg", SignatureAlgorithm.HS256);
        header.put("sign_type", "SIGN");

        // Read the clock once so exp and timestamp are consistent with each other.
        long now = System.currentTimeMillis();
        Map<String, Object> payload = new HashMap<>();
        payload.put("api_key", apiKeyId);
        payload.put("exp", now + EXPIRE_MILLIS);
        payload.put("timestamp", now);

        try {
            return Jwts.builder().setHeader(header)
                    .setPayload(JSON.toJSONString(payload))
                    .signWith(SignatureAlgorithm.HS256, secret.getBytes(StandardCharsets.UTF_8))
                    .compact();
        } catch (Exception e) {
            // Previously swallowed (returning null), which only surfaced later as an opaque NPE.
            throw new IllegalStateException("failed to sign API token", e);
        }
    }
}
ExecuteSSEUtil 类
package com.demo.utils; import com.demo.listener.SSEListener; import lombok.extern.slf4j.Slf4j; import okhttp3.MediaType; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.sse.EventSource; import okhttp3.sse.EventSources; @Slf4j public class ExecuteSSEUtil { public static void executeSSE(String url, String authToken, SSEListener eventSourceListener, String chatGlm) throws Exception { RequestBody formBody = RequestBody.create(chatGlm, MediaType.parse("application/json; charset=utf-8")); Request.Builder requestBuilder = new Request.Builder(); requestBuilder.addHeader("Authorization", authToken); Request request = requestBuilder.url(url).post(formBody).build(); EventSource.Factory factory = EventSources.createFactory(OkHttpUtil.getInstance()); //创建事件 factory.newEventSource(request, eventSourceListener); eventSourceListener.getCountDownLatch().await(); } }
OkHttpUtil 类
package com.demo.utils; import okhttp3.ConnectionPool; import okhttp3.OkHttpClient; import java.net.Proxy; import java.util.concurrent.TimeUnit; public class OkHttpUtil { private static OkHttpClient okHttpClient; public static ConnectionPool connectionPool = new ConnectionPool(10, 5, TimeUnit.MINUTES); public static OkHttpClient getInstance() { if (okHttpClient == null) { //加同步安全 synchronized (OkHttpClient.class) { if (okHttpClient == null) { //okhttp可以缓存数据....指定缓存路径 okHttpClient = new OkHttpClient.Builder()//构建器 .proxy(Proxy.NO_PROXY) //来屏蔽系统代理 .connectionPool(connectionPool) .connectTimeout(600, TimeUnit.SECONDS)//连接超时 .writeTimeout(600, TimeUnit.SECONDS)//写入超时 .readTimeout(600, TimeUnit.SECONDS)//读取超时 .build(); okHttpClient.dispatcher().setMaxRequestsPerHost(200); okHttpClient.dispatcher().setMaxRequests(200); } } } return okHttpClient; } }
ChatGlmDto 请求实体类
package com.demo.params.req;

import lombok.Data;

/**
 * Chat request body accepted by the {@code /sse-invoke} endpoint and forwarded, serialized as
 * JSON, to the remote model API.
 *
 * <p>Lombok {@code @Data} generates getters, setters, {@code equals}/{@code hashCode},
 * {@code toString} and a no-arg constructor.
 *
 * <p>Created by WeiRan on 2023.03.20 19:19
 */
@Data
public class ChatGlmDto {

    /** Caller-supplied message id; echoed as the SSE {@code id:} field and used in log lines. */
    private String messageId;

    /**
     * Prompt sent to the model.
     *
     * <p>It is typed as {@code Object} so any JSON shape is accepted. The sample request uses a
     * list of {@code {"role": ..., "content": ...}} objects.
     */
    private Object prompt;

    /** Request task number; not read by the code shown here — presumably used by the remote API. */
    private String requestTaskNo;

    /** Defaults to {@code true}. NOTE(review): presumably requests incremental output — verify. */
    private boolean incremental = true;

    /** Defaults to {@code true}. NOTE(review): meaning is defined by the remote API — confirm. */
    private boolean notSensitive = true;
}
5. 接口调用调试
我这里就直接使用curl命令来调用了
curl 'http://localhost:8080/sse-invoke' --data '{"prompt":[{"role":"user","content":"泰山有多高?"}]}' -H 'Content-Type: application/json'
返回结果:
分割线---------------------------------------------------------------------------------------------------------------------------------
创作不易,三连支持一下吧 👍
最后送大家一句话白驹过隙,沧海桑田