「乾貨」如何從0寫一個服務網關?

「乾貨」如何從0寫一個服務網關?

java程序媛之家

一、引 言

什麼是網關?為什麼需要使用網關?

「乾貨」如何從0寫一個服務網關?

如圖所示,在不使用網關的情況下,我們的服務是直接暴露給服務調用方。當調用方增多,勢必需要添加定製化訪問權限、校驗等邏輯。當添加API網關後,再第三方調用端和服務提供方之間就創建了一面牆,這面牆直接與調用方通信進行權限控制。

本文所實現的網關源碼抄襲了---Oh,不對,是借鑑。借鑑了Zuul網關的源碼,提煉出其核心思路,實現了一套簡單的網關源碼,博主將其改名為Eatuul。

題外話

本文是業內能搜到的第一篇自己動手實現網關的文章。博主寫的手把手系列的文章,目的是在以最簡單的方式,揭露出中間件的核心原理,讓讀者能夠迅速瞭解實現的核心。需要說明的是,這不是源碼分析系列的文章,因此寫出來的代碼,省去了一些複雜的內容,畢竟大家能理解到該中間件的核心原理即可。如果想看源碼分析系列的,請關注博主,後期會將spring、spring boot、dubbo、mybatis等開源框架一一揭示。

二、正 文

設計思路

先大致說一下,就是定義一個Servlet接收請求。然後經過preFilter(封裝請求參數),routeFilter(轉發請求),postFilter(輸出內容)。三個過濾器之間,共享request、response以及其他的一些全局變量。如下圖所示

「乾貨」如何從0寫一個服務網關?

和真正的Zuul的區別?

主要區別有如下幾點

(1)Zuul中在異常處理模塊,有一個ErrorFilter來處理,博主在實現的時候偷懶了,略去。

(2)Zuul中PreFilters,RoutingFilters,PostFilters默認都實現了一組,具體如下表所示

「乾貨」如何從0寫一個服務網關?

博主總不可能每一個都給你們實現一遍吧。所以偷懶了,每種只實現一個。但是調用順序還是不變,按照PreFilters->RoutingFilters->PostFilters的順序調用

(3)在routeFilters確實有轉發請求的Filter,然而博主偷天換日了,改用RestTemplate實現.

代碼結構

大家去spring官網上搭建一套springboot的項目,博主就不展示pom的代碼了。直接將項目結構展示一下,如下圖所示

「乾貨」如何從0寫一個服務網關?

EatuulServlet.java

這個是網關的入口,邏輯也十分簡單,分為三步

(1)將request,response放入threadlocal中

(2)執行三組過濾器

(3)清除threadlocal中的的環境變量

源碼如下

package com.rjzheng.eatuul.http; 

import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@WebServlet(name = "eatuul", urlPatterns = "/*")
public class EatuulServlet extends HttpServlet {
private EatRunner eatRunner = new EatRunner();
@Override
public void service(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException {
//將request,和response放入上下文對象中
eatRunner.init(req, resp);
try {
//執行前置過濾
eatRunner.preRoute();
//執行過濾
eatRunner.route();
//執行後置過濾
eatRunner.postRoute();
} catch (Throwable e) {
RequestContext.getCurrentContext().getResponse()
.sendError(HttpServletResponse.SC_NOT_FOUND, e.getMessage());
} finally {
//清除變量
RequestContext.getCurrentContext().unset();
}
}
}

EatuulRunner.java

這個是具體的執行器。需要說明一下,在Zuul中,ZuulRunner在獲取具體有哪些過濾器的時候,有一個FileLoader可以動態讀取配置加載。博主在實現我們自己的EatuulRunner時候,略去動態讀取的過程,直接靜態寫死。

源碼如下

package com.rjzheng.eatuul.http;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.filter.post.SendResponseFilter;
import com.rjzheng.eatuul.filter.pre.RequestWrapperFilter;
import com.rjzheng.eatuul.filter.route.RoutingFilter;
public class EatRunner {
//靜態寫死過濾器
private ConcurrentHashMap> hashFiltersByType = new ConcurrentHashMap>(){{
put("pre",new ArrayList(){{
add(new RequestWrapperFilter());
}});
put("route",new ArrayList(){{
add(new RoutingFilter());
}});
put("post",new ArrayList(){{
add(new SendResponseFilter());
}});
}};
public void init(HttpServletRequest req, HttpServletResponse resp) {
RequestContext ctx = RequestContext.getCurrentContext();
ctx.setRequest(req);
ctx.setResponse(resp);
}
public void preRoute() throws Throwable {
runFilters("pre");
}
public void route() throws Throwable{
runFilters("route");
}
public void postRoute() throws Throwable{
runFilters("post");
}
public void runFilters(String sType) throws Throwable {
List list = this.hashFiltersByType.get(sType);
if (list != null) {
for (int i = 0; i < list.size(); i++) {
EatuulFilter zuulFilter = list.get(i);
zuulFilter.run();

}
}
}
}

EatuulFilter.java

接下來就是一系列Filter的代碼了,先上父類EatuulFilter的源碼

package com.rjzheng.eatuul.filter;
public abstract class EatuulFilter {
abstract public String filterType();
abstract public int filterOrder();
abstract public void run();
}

RequestWrapperFilter.java

這個是PreFilter,前置執行過濾器,負責封裝請求。步驟如下所示

(1)封裝請求頭

(2)封裝請求體

(3)構造出RestTemplate能識別的RequestEntity

(4)將RequestEntity放入全局threadlocal之中

代碼如下所示

package com.rjzheng.eatuul.filter.pre; 

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class RequestWrapperFilter extends EatuulFilter{
@Override
public String filterType() {
// TODO Auto-generated method stub
return "pre";
}
@Override
public int filterOrder() {
// TODO Auto-generated method stub
return -1;
}
@Override
public void run() {
String rootURL = "http://localhost:9090";
RequestContext ctx =RequestContext.getCurrentContext();
HttpServletRequest servletRequest = ctx.getRequest();
String targetURL = rootURL + servletRequest.getRequestURI();
RequestEntity requestEntity = null;
try {
requestEntity = createRequestEntity(servletRequest, targetURL);
} catch (Exception e) {
e.printStackTrace();
}
//4、將requestEntity放入全局threadlocal之中
ctx.setRequestEntity(requestEntity);
}
private RequestEntity createRequestEntity(HttpServletRequest request,String url) throws URISyntaxException, IOException {
String method = request.getMethod();
HttpMethod httpMethod = HttpMethod.resolve(method);
//1、封裝請求頭
MultiValueMap headers =createRequestHeaders(request);
//2、封裝請求體

byte[] body = createRequestBody(request);
//3、構造出RestTemplate能識別的RequestEntity
RequestEntity requestEntity = new RequestEntity(body,headers,httpMethod, new URI(url));
return requestEntity;
}
private byte[] createRequestBody(HttpServletRequest request) throws IOException {
InputStream inputStream = request.getInputStream();
return StreamUtils.copyToByteArray(inputStream);
}
private MultiValueMap createRequestHeaders(HttpServletRequest request) {
HttpHeaders headers = new HttpHeaders();
List headerNames = Collections.list(request.getHeaderNames());
for(String headerName:headerNames) {
List headerValues = Collections.list(request.getHeaders(headerName));
for(String headerValue:headerValues) {
headers.add(headerName, headerValue);
}
}
return headers;
}
}

RoutingFilter.java

這個是routeFilter,這裡我偷懶了,直接做轉發請求,並且將返回值ResponseEntity放入全局threadlocal中

package com.rjzheng.eatuul.filter.route;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class RoutingFilter extends EatuulFilter{
@Override

public String filterType() {
// TODO Auto-generated method stub
return "route";
}
@Override
public int filterOrder() {
// TODO Auto-generated method stub
return 0;
}
@Override
public void run(){
RequestContext ctx = RequestContext.getCurrentContext();
RequestEntity requestEntity = ctx.getRequestEntity();
RestTemplate restTemplate = new RestTemplate();
ResponseEntity responseEntity = restTemplate.exchange(requestEntity,byte[].class);
ctx.setResponseEntity(responseEntity);
}
}

SendResponseFilter.java

這個是postFilters,將ResponseEntity輸出即可

package com.rjzheng.eatuul.filter.post;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;
public class SendResponseFilter extends EatuulFilter{
@Override
public String filterType() {
return "post";
}
@Override
public int filterOrder() {
return 1000;
}
@Override
public void run() {
try {
addResponseHeaders();
writeResponse();
} catch (Exception e) {
e.printStackTrace();
}

}
private void addResponseHeaders() {
RequestContext ctx = RequestContext.getCurrentContext();
HttpServletResponse servletResponse = ctx.getResponse();
ResponseEntity responseEntity = ctx.getResponseEntity();
HttpHeaders httpHeaders = responseEntity.getHeaders();
for(Map.Entry> entry:httpHeaders.entrySet()) {
String headerName = entry.getKey();
List headerValues = entry.getValue();
for(String headerValue:headerValues) {
servletResponse.addHeader(headerName, headerValue);
}
}
}
private void writeResponse()throws Exception {
RequestContext ctx = RequestContext.getCurrentContext();
HttpServletResponse servletResponse = ctx.getResponse();
if (servletResponse.getCharacterEncoding() == null) { // only set if not set
servletResponse.setCharacterEncoding("UTF-8");
}
ResponseEntity responseEntity = ctx.getResponseEntity();
if(responseEntity.hasBody()) {
byte[] body = (byte[]) responseEntity.getBody();
ServletOutputStream outputStream = servletResponse.getOutputStream();
outputStream.write(body);
outputStream.flush();
}
}
}

RequestContext.java

最後是一直在說的全局threadlocal變量

package com.rjzheng.eatuul.http;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
public class RequestContext extends ConcurrentHashMap {

protected static Class extends RequestContext> contextClass = RequestContext.class;
protected static final ThreadLocal extends RequestContext> threadLocal = new ThreadLocal() {
@Override
protected RequestContext initialValue() {
try {
return contextClass.newInstance();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}
};
public static RequestContext getCurrentContext() {
RequestContext context = threadLocal.get();
return context;
}
public HttpServletRequest getRequest() {
return (HttpServletRequest) get("request");
}
public void setRequest(HttpServletRequest request) {
put("request", request);
}
public HttpServletResponse getResponse() {
return (HttpServletResponse) get("response");
}
public void setResponse(HttpServletResponse response) {
set("response", response);
}
public void setRequestEntity(RequestEntity requestEntity){
set("requestEntity",requestEntity);
}
public RequestEntity getRequestEntity() {
return (RequestEntity) get("requestEntity");
}
public void setResponseEntity(ResponseEntity responseEntity){
set("responseEntity",responseEntity);
}
public ResponseEntity getResponseEntity() {
return (ResponseEntity) get("responseEntity");
}
public void set(String key, Object value) {
if (value != null)
put(key, value);
else
remove(key);
}
public void unset() {
threadLocal.remove();
}
}

如何測試?

自己另外起一個server端口為9090如下所示

package com.rjzheng.eatservice;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.web.servlet.ServletComponentScan;
import com.rjzheng.eatservice.controller.IndexController;
@SpringBootApplication
@ServletComponentScan(basePackageClasses = IndexController.class)
public class Application {
public static void main(String[] args) {
new SpringApplicationBuilder(Application.class).properties("server.port=9090").run(args);
}
}

再來一個controller

package com.rjzheng.eatservice.controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
public class IndexController {
@RequestMapping("/index")
public String index() {
return "hello!world";
}
}

然後,你就發現可以從localhost:8080/index進行跳轉訪問了

三、結 論

本文模擬了一下zuul網關的源碼,借鑑了一下其精髓的部分。希望大家能有所收穫

--(完) --

看完本文有收穫?請轉發分享給更多人

關注「java程序媛之家」,提升Java技能

「乾貨」如何從0寫一個服務網關?

「乾貨」如何從0寫一個服務網關?


分享到:


相關文章: