package com.geoway.sso.client;

import com.geoway.sso.client.constant.SsoConstant;
import com.geoway.sso.client.filter.ClientFilter;
import com.geoway.sso.client.filter.ParamFilter;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Central container for the single-sign-on (SSO) client filters.
 * <p>
 * On {@link #init(FilterConfig)} it copies the connection settings inherited from
 * {@link ParamFilter} into every configured {@link ClientFilter}. For each request that is subject
 * to interception it runs the client filters in order and stops as soon as one of them denies
 * access.
 * <p>
 * URL rules ({@link #setExcludeUrls(String)}, {@link #setIncludeUrls(String)}) are
 * comma-separated and are matched against {@link HttpServletRequest#getServletPath()}. A rule that
 * ends with {@code SsoConstant.URL_FUZZY_MATCH} is a prefix match on the part before that suffix;
 * any other rule must equal the servlet path exactly. Whitespace around rules and blank entries
 * are ignored.
 *
 * @author ALMJ
 */
public class NsSsoContainer extends ParamFilter implements Filter {
    /**
     * Comma-separated URL rules whose matching requests bypass the client filters and go straight
     * down the chain. Requests carrying the logout header are never bypassed.
     */
    protected String excludeUrls;

    /**
     * Comma-separated URL rules. When at least one rule is configured, only matching requests are
     * intercepted and every other request goes straight down the chain. When unset or blank, every
     * non-excluded request is intercepted. Exclude rules take precedence over include rules, and
     * requests carrying the logout header are always intercepted.
     */
    protected String includeUrls;

    private ClientFilter[] filters;

    /**
     * Propagates the SSO connection settings to every client filter and initialises it.
     *
     * @param filterConfig the servlet filter configuration, forwarded to each client filter
     * @throws ServletException         if thrown by a client filter's {@code init}
     * @throws IllegalArgumentException if no client filters were configured
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        if (filters == null || filters.length == 0) {
            throw new IllegalArgumentException("filters不能为空");
        }
        for (ClientFilter filter : filters) {
            filter.setAppId(getAppId());
            filter.setAppSecret(getAppSecret());
            filter.setServerUrl(getServerUrl());
            filter.setServerHttpsUrl(getServerHttpsUrl());
            filter.setLoginApi(getLoginApi());
            filter.init(filterConfig);
        }
    }

    /**
     * Decides whether the request is intercepted and, if so, asks each client filter in turn
     * whether access is allowed.
     * <p>
     * If any filter returns {@code false}, the chain is not continued; that filter is responsible
     * for having produced the response. Otherwise the request proceeds down the chain.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        String reqUrl = httpRequest.getServletPath();
        // Requests carrying the logout header always go through the client filters,
        // regardless of the exclude/include URL rules.
        boolean logoutRequest = httpRequest.getHeader(SsoConstant.LOGOUT_PARAMETER_NAME) != null;
        if (!logoutRequest && !shouldIntercept(reqUrl)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        for (ClientFilter filter : filters) {
            if (!filter.isAccessAllowed(httpRequest, httpResponse)) {
                return;
            }
        }
        chain.doFilter(request, response);
    }

    /**
     * Applies the URL rules: excluded URLs are never intercepted; otherwise a URL is intercepted
     * when no include rules are configured or when it matches one of them.
     *
     * @param url the request's servlet path
     * @return {@code true} if the client filters should examine the request
     */
    private boolean shouldIntercept(String url) {
        if (isExcludeUrl(url)) {
            return false;
        }
        List<String> includeRules = parseUrlRules(includeUrls);
        return includeRules.isEmpty() || matchUrls(includeRules, url);
    }

    /**
     * @param url the request's servlet path
     * @return {@code true} if the URL matches one of the configured exclude rules
     */
    private boolean isExcludeUrl(String url) {
        return matchUrls(parseUrlRules(excludeUrls), url);
    }

    @Override
    public void destroy() {
        if (filters == null || filters.length == 0) {
            return;
        }
        for (ClientFilter filter : filters) {
            filter.destroy();
        }
    }

    public void setExcludeUrls(String excludeUrls) {
        this.excludeUrls = excludeUrls;
    }

    public void setIncludeUrls(String includeUrls) {
        this.includeUrls = includeUrls;
    }

    public void setFilters(ClientFilter... filters) {
        this.filters = filters;
    }

    /**
     * Splits a comma-separated rule string into trimmed, non-blank rules.
     *
     * @param urlsString the raw rule string; may be {@code null}
     * @return the rules in configuration order; empty when nothing usable is configured
     */
    private static List<String> parseUrlRules(String urlsString) {
        String source = urlsString == null ? "" : urlsString;
        return Arrays.stream(source.split(","))
                .map(String::trim)
                // Blank entries (e.g. "a, ,b") must not turn into a rule matching the empty path.
                .filter(rule -> !rule.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Matches a URL against exact and prefix (fuzzy) rules.
     *
     * @param rules trimmed, non-blank rules as produced by {@link #parseUrlRules(String)}
     * @param url   the request's servlet path
     * @return {@code true} if any rule matches, otherwise {@code false}
     */
    private static boolean matchUrls(List<String> rules, String url) {
        String fuzzySuffix = SsoConstant.URL_FUZZY_MATCH;
        for (String rule : rules) {
            if (rule.endsWith(fuzzySuffix)) {
                // Strip only the trailing wildcard; earlier occurrences are part of the prefix.
                String prefix = rule.substring(0, rule.length() - fuzzySuffix.length());
                if (url.startsWith(prefix)) {
                    return true;
                }
            } else if (rule.equals(url)) {
                return true;
            }
        }
        return false;
    }
}