package com.geoway.sso.client.filter;

import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.geoway.sso.client.constant.Oauth2Constant;
import com.geoway.sso.client.constant.SsoConstant;
import com.geoway.sso.client.rpc.Result;
import com.geoway.sso.client.rpc.RpcAccessToken;
import com.geoway.sso.client.util.Oauth2Utils;
import com.geoway.sso.client.util.SessionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.io.PrintWriter;

/**
 * @author ALMJ
 * @desc 单点登录Filter
 */
public class LoginFilter extends ClientFilter {
    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Override
    public boolean isAccessAllowed(HttpServletRequest request, HttpServletResponse response) throws IOException {
        return true;
    }

    /**
     * 获取accessToken和用户信息存session
     *
     * @param code
     * @param request
     */
    public void getAccessToken(String code, HttpServletRequest request) {
        Result<RpcAccessToken> result = Oauth2Utils.getAccessToken(getServerUrl(), getAppId(),
                getAppSecret(), code);
        if (!result.isSuccess()) {
            logger.error("getAccessToken has error, message:{}", result.getMessage());
            return;
        }
        setAccessTokenInSession(result.getData(), request);
    }

    protected Result<RpcAccessToken> queryAccessToken(String accessToken, HttpServletRequest request) {
        Result<RpcAccessToken> result = Oauth2Utils.queryAccessToken(getServerUrl(), accessToken);
        return result;
    }

    /**
     * 通过refreshToken参数调用http请求延长服务端session，并返回新的accessToken
     *
     * @param refreshToken
     * @param request
     * @return
     */
    protected boolean refreshToken(String refreshToken, HttpServletRequest request) {
        Result<RpcAccessToken> result = Oauth2Utils.refreshToken(getServerUrl(), getAppId(), refreshToken);
        if (!result.isSuccess()) {
            logger.error("refreshToken has error, message:{}", result.getMessage());
            return false;
        }
        return setAccessTokenInSession(result.getData(), request);
    }

    protected boolean setAccessTokenInSession(RpcAccessToken rpcAccessToken, HttpServletRequest request) {
        if (rpcAccessToken == null) {
            return false;
        }
        // 记录accessToken到本地session
        SessionUtils.setAccessToken(request, rpcAccessToken);
        // 记录本地session和accessToken映射
        recordSession(request, rpcAccessToken.getAccessToken(), rpcAccessToken.getUser().getLoginName());
        return true;
    }

    private void recordSession(final HttpServletRequest request, String accessToken, String userName) {
        final HttpSession session = request.getSession();
        getSessionMappingStorage().removeBySessionById(session.getId());
        getSessionMappingStorage().addSessionById(accessToken, session);
        getSessionMappingStorage().removeTokenByUserName(userName);
        getSessionMappingStorage().addTokenByUserName(accessToken, userName);
    }


    /**
     * 去除返回地址中的票据参数
     *
     * @param request
     * @return
     * @throws IOException
     */
    public void redirectLocalRemoveCode(HttpServletRequest request, HttpServletResponse response) throws IOException {
        String currentUrl = getCurrentUrl(request);
        currentUrl = currentUrl.substring(0, currentUrl.indexOf(Oauth2Constant.AUTH_CODE) - 1);
        response.sendRedirect(currentUrl);
    }

    /**
     * 获取当前请求地址
     *
     * @param request
     * @return
     */
    private String getCurrentUrl(HttpServletRequest request) {
        return new StringBuilder().append(request.getRequestURL())
                .append(request.getQueryString() == null ? "" : "?" + request.getQueryString()).toString();
    }

    protected boolean isAjaxRequest(HttpServletRequest request) {
        String requestedWith = request.getHeader("X-Requested-With");
        return requestedWith != null ? "XMLHttpRequest".equals(requestedWith) : false;
    }

    /**
     * 将response转成json独享
     *
     * @param response
     * @param code
     * @param message
     * @throws IOException
     */
    protected void responseJson(HttpServletResponse response, int code, String message) throws IOException {
        response.setContentType("application/json;charset=UTF-8");
        response.setStatus(200);
        PrintWriter writer = response.getWriter();
        writer.write(JSON.toJSONString(Result.create(code, message)));
        writer.flush();
        writer.close();
    }

    protected void responseJson(HttpServletResponse response, Result result) throws IOException {
        response.setContentType("application/json;charset=UTF-8");
        response.setStatus(200);
        PrintWriter writer = response.getWriter();
        writer.write(JSON.toJSONString(result));
        writer.flush();
        writer.close();
    }

    //前后端分离登录拦截返回
    protected void markLoginResponse(HttpServletResponse response, String msg) throws IOException {
        response.setStatus(HttpStatus.OK.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setCharacterEncoding("UTF-8");
        response.setHeader("Cache-Control", "no-cache, must-revalidate");
        JSONObject baseResopnse = new JSONObject();
        baseResopnse.put("code", SsoConstant.NO_LOGIN);
        baseResopnse.put("status",SsoConstant.RESPONSE_STATUS_LOGINOUT);
        baseResopnse.put("message","无效token或token已过期");
        if (StrUtil.isNotEmpty(msg)) {
            baseResopnse.put("message",msg);
        }
        String jsonResult = JSON.toJSONString(baseResopnse, SerializerFeature.WriteNullStringAsEmpty);
        response.getWriter().write(jsonResult);
    }
}