package com.geoway.web.interceptor;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.FileTypeUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.geoway.base.dto.BaseResponse;
import com.geoway.web.config.FileUploadConfiguration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Spring MVC interceptor that rejects multipart uploads containing files that are not on the
 * configured whitelist ({@link FileUploadConfiguration#getAllowFileTypes()}).
 *
 * <p>A file is accepted only when BOTH its filename extension AND the type detected from its
 * file header (hutool {@link FileTypeUtil#getType(InputStream)}) are in the whitelist. A whitelist
 * entry of {@code "all"} (case-insensitive) disables the check. When any file is rejected, a JSON
 * {@link BaseResponse} failure body is written and the handler is not invoked.
 */
@Slf4j
@Component
public class FileTypeFilterInterceptor extends HandlerInterceptorAdapter {

    /** Whitelist entry that disables file-type filtering entirely. */
    private static final String ALLOW_ALL = "all";

    @Autowired
    FileUploadConfiguration fileTypeConfig;

    /**
     * Checks every uploaded file of a multipart request against the whitelist.
     * Non-multipart requests pass through untouched.
     *
     * @return {@code true} to continue processing; {@code false} after a failure response has
     *         been written because at least one file is not allowed
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {

        if (!(request instanceof MultipartHttpServletRequest)) {
            return true;
        }
        List<String> allowFileTypes = normalizeFileTypes(fileTypeConfig.getAllowFileTypes());
        if (allowFileTypes.contains(ALLOW_ALL)) {
            return true;
        }

        MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) request;
        // Use getMultiFileMap rather than getFileMap: getFileMap keeps only the first file per
        // form-field name, so additional files sent under the same name would skip the check.
        Map<String, List<MultipartFile>> allFiles = multipartRequest.getMultiFileMap();
        List<String> notAllowedFiles = getNotAllowedFiles(allFiles, allowFileTypes, handler);

        if (CollUtil.isNotEmpty(notAllowedFiles)) {
            String message = CharSequenceUtil.format("不被允许的文件类型！不被允许的文件列表如下：{}", notAllowedFiles);
            markFailureResponse(response, message);
            return false;
        }

        return true;
    }

    /**
     * Returns a trimmed, lower-cased copy of the configured whitelist with blank entries removed.
     * A {@code null} or empty configuration yields an empty list, so every uploaded file is
     * rejected (fail closed) instead of the request failing with a NullPointerException.
     */
    private static List<String> normalizeFileTypes(List<String> configured) {
        List<String> normalized = new ArrayList<>();
        if (CollUtil.isEmpty(configured)) {
            return normalized;
        }
        for (String entry : configured) {
            if (StrUtil.isNotBlank(entry)) {
                normalized.add(entry.trim().toLowerCase(Locale.ROOT));
            }
        }
        return normalized;
    }

    /**
     * Collects the original filenames of all uploaded files (every file under every field name)
     * that fail {@link #isAllowedFile}.
     *
     * @return the rejected filenames; an empty list (never {@code null}) when all files are allowed
     */
    private List<String> getNotAllowedFiles(Map<String, List<MultipartFile>> allFiles,
                                            List<String> allowFileTypes, Object handler) throws IOException {
        List<String> notAllowedFilenames = new ArrayList<>();
        for (List<MultipartFile> files : allFiles.values()) {
            for (MultipartFile file : files) {
                if (!isAllowedFile(file, allowFileTypes, handler)) {
                    notAllowedFilenames.add(file.getOriginalFilename());
                }
            }
        }
        return notAllowedFilenames;
    }

    /**
     * Decides whether a single uploaded file is on the whitelist.
     *
     * <p>Both the filename extension and the type detected from the file header must be contained
     * in {@code allowFileTypes`; both are lower-cased before comparison, so the whitelist is
     * expected to be lower case. A file whose extension or header type cannot be determined is
     * rejected. NOTE(review): header detection only recognises formats with a known magic number,
     * so presumably plain-text formats (txt, csv, ...) are always rejected — confirm this is intended.
     *
     * @param file           the uploaded file to check
     * @param allowFileTypes whitelisted type/extension names (lower case)
     * @param handler        the target handler; currently unused
     * @return {@code true} if the file is allowed
     * @throws IOException if the file content cannot be opened
     */
    boolean isAllowedFile(MultipartFile file, List<String> allowFileTypes, Object handler) throws IOException {
        // Check the extension first: it needs no I/O, and a missing extension is rejected outright.
        String extType = FileUtil.extName(file.getOriginalFilename());
        if (StrUtil.isBlank(extType) || !allowFileTypes.contains(extType.toLowerCase(Locale.ROOT))) {
            return false;
        }

        String headerType;
        try (InputStream in = file.getInputStream()) {
            headerType = FileTypeUtil.getType(in);
        }
        // A header that cannot be identified is rejected.
        return StrUtil.isNotBlank(headerType)
                && allowFileTypes.contains(headerType.toLowerCase(Locale.ROOT));
    }

    /**
     * Writes a JSON failure body: a {@link BaseResponse} marked as failed carrying {@code msg}.
     * The HTTP status is set to 200 OK; the failure is conveyed only through the body.
     */
    private void markFailureResponse(HttpServletResponse response, String msg) throws IOException {
        response.setStatus(HttpStatus.OK.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        response.setCharacterEncoding("UTF-8");
        response.setHeader("Cache-Control", "no-cache, must-revalidate");
        BaseResponse failure = new BaseResponse();
        failure.markFailure();
        failure.setMessage(msg);
        String jsonResult = JSON.toJSONString(failure, SerializerFeature.WriteNullStringAsEmpty);
        response.getWriter().write(jsonResult);
    }

}
