/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.security.authentication.server;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Properties;
import javax.servlet.ServletException;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.security.authentication.client.AuthenticationException;
import org.apache.hadoop.security.authentication.server.AltKerberosAuthenticationHandler;
import org.apache.hadoop.security.authentication.server.AuthenticationToken;
import org.apache.hadoop.security.authentication.util.CertificateUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Authentication handler that extends {@link AltKerberosAuthenticationHandler}
 * and authenticates a request from a signed JWT carried in an HTTP cookie.
 *
 * <p>When the cookie is missing, or the token it carries cannot be parsed or
 * fails validation, the client is redirected to the configured authentication
 * provider URL with the originally requested URL appended as the
 * {@code originalUrl} query parameter.
 *
 * <p>Configuration properties read by {@link #init(Properties)}:
 * <ul>
 *   <li>{@value #AUTHENTICATION_PROVIDER_URL} - required; redirect target.</li>
 *   <li>{@value #PUBLIC_KEY_PEM} - PEM public key used for signature
 *       verification; required unless a key was injected through
 *       {@link #setPublicKey(RSAPublicKey)} beforehand.</li>
 *   <li>{@value #EXPECTED_JWT_AUDIENCES} - optional comma separated list of
 *       accepted audiences; when absent any audience is accepted.</li>
 *   <li>{@value #JWT_COOKIE_NAME} - optional cookie name; defaults to
 *       {@code hadoop-jwt}.</li>
 * </ul>
 */
public class JWTRedirectAuthenticationHandler
extends AltKerberosAuthenticationHandler {
    private static final Logger LOG = LoggerFactory.getLogger(JWTRedirectAuthenticationHandler.class);
    public static final String AUTHENTICATION_PROVIDER_URL = "authentication.provider.url";
    public static final String PUBLIC_KEY_PEM = "public.key.pem";
    public static final String EXPECTED_JWT_AUDIENCES = "expected.jwt.audiences";
    public static final String JWT_COOKIE_NAME = "jwt.cookie.name";
    private static final String ORIGINAL_URL_QUERY_PARAM = "originalUrl=";
    private String authenticationProviderUrl = null;
    private RSAPublicKey publicKey = null;
    /** Accepted audiences; {@code null} means "accept any audience". */
    private List<String> audiences = null;
    private String cookieName = "hadoop-jwt";

    /**
     * Injects the public key used for signature verification. When a key is
     * set before {@link #init(Properties)} runs, the {@value #PUBLIC_KEY_PEM}
     * property is not consulted.
     *
     * @param pk the RSA public key to verify token signatures with
     */
    public void setPublicKey(RSAPublicKey pk) {
        this.publicKey = pk;
    }

    /**
     * Initializes the handler from the given configuration after delegating
     * to the superclass.
     *
     * @param config handler configuration properties
     * @throws ServletException if the provider URL is not configured, or if
     *         no public key was injected and none is configured
     */
    @Override
    public void init(Properties config) throws ServletException {
        super.init(config);

        this.authenticationProviderUrl = config.getProperty(AUTHENTICATION_PROVIDER_URL);
        if (this.authenticationProviderUrl == null) {
            throw new ServletException("Authentication provider URL must not be null - configure: authentication.provider.url");
        }

        // A key injected via setPublicKey() takes precedence over configuration.
        if (this.publicKey == null) {
            String pemPublicKey = config.getProperty(PUBLIC_KEY_PEM);
            if (pemPublicKey == null) {
                throw new ServletException("Public key for signature validation must be provisioned.");
            }
            this.publicKey = CertificateUtil.parseRSAPublicKey(pemPublicKey);
        }

        String auds = config.getProperty(EXPECTED_JWT_AUDIENCES);
        if (auds != null) {
            List<String> parsed = new ArrayList<String>();
            for (String a : auds.split(",")) {
                // Trim so that "a, b" is treated the same as "a,b"; without
                // this an entry such as " b" can never match a token audience.
                parsed.add(a.trim());
            }
            this.audiences = parsed;
        }

        String customCookieName = config.getProperty(JWT_COOKIE_NAME);
        if (customCookieName != null) {
            this.cookieName = customCookieName;
        }
    }

    /**
     * Authenticates the request from the JWT cookie.
     *
     * <p>If the cookie is absent, or the token is unparseable or invalid, a
     * redirect to the login URL is sent on {@code response} and {@code null}
     * is returned. Otherwise an {@link AuthenticationToken} is created with
     * the token's subject as both user name and principal.
     *
     * @param request the incoming HTTP request
     * @param response the HTTP response, used to send the redirect
     * @return the authentication token, or {@code null} when a redirect was sent
     * @throws IOException if sending the redirect fails
     * @throws AuthenticationException declared by the overridden method; not
     *         thrown directly by this implementation
     */
    @Override
    public AuthenticationToken alternateAuthenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, AuthenticationException {
        String serializedJWT = this.getJWTFromCookie(request);
        if (serializedJWT == null) {
            String loginURL = this.constructLoginURL(request);
            LOG.info("sending redirect to: {}", loginURL);
            response.sendRedirect(loginURL);
            return null;
        }

        String userName = null;
        boolean valid = false;
        try {
            SignedJWT jwtToken = SignedJWT.parse(serializedJWT);
            valid = this.validateToken(jwtToken);
            if (valid) {
                userName = jwtToken.getJWTClaimsSet().getSubject();
                LOG.info("USERNAME: {}", userName);
            } else {
                // Deliberately do not log the serialized token: a token that
                // fails only one check here may still be a usable credential
                // elsewhere.
                LOG.warn("jwtToken failed validation");
            }
        }
        catch (ParseException pe) {
            // Fail closed: if the claims could not be read after validation
            // succeeded, do not issue a token with an unknown user name.
            valid = false;
            LOG.warn("Unable to parse the JWT token", pe);
        }

        if (!valid) {
            String loginURL = this.constructLoginURL(request);
            LOG.info("token validation failed - sending redirect to: {}", loginURL);
            response.sendRedirect(loginURL);
            return null;
        }

        LOG.debug("Issuing AuthenticationToken for user.");
        return new AuthenticationToken(userName, userName, this.getType());
    }

    /**
     * Looks up the serialized JWT in the request cookies.
     *
     * @param req the incoming HTTP request
     * @return the value of the first cookie whose name matches the configured
     *         cookie name, or {@code null} if there is no such cookie
     */
    protected String getJWTFromCookie(HttpServletRequest req) {
        Cookie[] cookies = req.getCookies();
        if (cookies == null) {
            return null;
        }
        for (Cookie cookie : cookies) {
            if (this.cookieName.equals(cookie.getName())) {
                LOG.info("{} cookie has been found and is being processed", this.cookieName);
                return cookie.getValue();
            }
        }
        return null;
    }

    /**
     * Builds the URL the client is redirected to for login: the provider URL
     * followed by an {@code originalUrl} parameter holding the request URL
     * and, if present, its query string.
     *
     * @param request the incoming HTTP request
     * @return the login URL
     */
    @VisibleForTesting
    String constructLoginURL(HttpServletRequest request) {
        // Use '&' when the provider URL already carries a query string.
        String delimiter = this.authenticationProviderUrl.contains("?") ? "&" : "?";
        // NOTE(review): the original URL is appended without URL-encoding, so
        // any '&' in its query string is indistinguishable from a parameter of
        // the provider URL. Left as-is because callers/providers may depend on
        // the exact format - confirm before changing.
        return this.authenticationProviderUrl + delimiter + ORIGINAL_URL_QUERY_PARAM
            + request.getRequestURL().toString() + this.getOriginalQueryString(request);
    }

    /**
     * Returns the request's query string prefixed with {@code ?}, or an empty
     * string when the request has no query string.
     */
    private String getOriginalQueryString(HttpServletRequest request) {
        String originalQueryString = request.getQueryString();
        if (originalQueryString == null) {
            return "";
        }
        return "?" + originalQueryString;
    }

    /**
     * Validates signature, audience and expiration of the token. All three
     * checks are always run so that each failure is logged.
     *
     * @param jwtToken the parsed token
     * @return {@code true} only if all three checks pass
     */
    protected boolean validateToken(SignedJWT jwtToken) {
        boolean sigValid = this.validateSignature(jwtToken);
        if (!sigValid) {
            LOG.warn("Signature could not be verified");
        }
        boolean audValid = this.validateAudiences(jwtToken);
        if (!audValid) {
            LOG.warn("Audience validation failed.");
        }
        boolean expValid = this.validateExpiration(jwtToken);
        if (!expValid) {
            LOG.info("Expiration validation failed.");
        }
        return sigValid && audValid && expValid;
    }

    /**
     * Verifies the token signature against the configured RSA public key.
     *
     * @param jwtToken the parsed token
     * @return {@code true} if the token is in the SIGNED state, carries a
     *         signature and that signature verifies; {@code false} otherwise,
     *         including when verification raises a {@link JOSEException}
     */
    protected boolean validateSignature(SignedJWT jwtToken) {
        if (JWSObject.State.SIGNED != jwtToken.getState()) {
            return false;
        }
        LOG.debug("JWT token is in a SIGNED state");
        if (jwtToken.getSignature() == null) {
            return false;
        }
        LOG.debug("JWT token signature is not null");
        try {
            RSASSAVerifier verifier = new RSASSAVerifier(this.publicKey);
            if (jwtToken.verify(verifier)) {
                LOG.debug("JWT token has been successfully verified");
                return true;
            }
            LOG.warn("JWT signature verification failed.");
        }
        catch (JOSEException je) {
            LOG.warn("Error while validating signature", je);
        }
        return false;
    }

    /**
     * Validates the token's audience claim against the configured audiences.
     *
     * @param jwtToken the parsed token
     * @return {@code true} if no audiences are configured, or if at least one
     *         audience in the token is among the configured ones;
     *         {@code false} otherwise, including when the claims cannot be
     *         parsed
     */
    protected boolean validateAudiences(SignedJWT jwtToken) {
        boolean valid = false;
        try {
            // Read the claims first so that an unparseable token is rejected
            // even when no audiences are configured.
            List<String> tokenAudienceList = jwtToken.getJWTClaimsSet().getAudience();
            if (this.audiences == null) {
                valid = true;
            } else {
                // Defensive null guard: whether getAudience() can return null
                // for a token without an "aud" claim depends on the library
                // version in use.
                if (tokenAudienceList != null) {
                    for (String aud : tokenAudienceList) {
                        if (this.audiences.contains(aud)) {
                            LOG.debug("JWT token audience has been successfully validated");
                            valid = true;
                            break;
                        }
                    }
                }
                if (!valid) {
                    LOG.warn("JWT audience validation failed.");
                }
            }
        }
        catch (ParseException pe) {
            LOG.warn("Unable to parse the JWT token.", pe);
        }
        return valid;
    }

    /**
     * Validates the token's expiration time against the current time.
     *
     * @param jwtToken the parsed token
     * @return {@code true} if the token has no expiration time or the current
     *         time is before it; {@code false} otherwise, including when the
     *         claims cannot be parsed
     */
    protected boolean validateExpiration(SignedJWT jwtToken) {
        try {
            Date expires = jwtToken.getJWTClaimsSet().getExpirationTime();
            // A token without an expiration claim is treated as never expiring.
            if (expires == null || new Date().before(expires)) {
                LOG.debug("JWT token expiration date has been successfully validated");
                return true;
            }
            LOG.warn("JWT expiration date validation failed.");
        }
        catch (ParseException pe) {
            LOG.warn("JWT expiration date validation failed.", pe);
        }
        return false;
    }
}

