/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.athena.jdbc.authentication;

import com.amazon.athena.jdbc.authentication.IdpCredentialsProvider;
import com.amazon.athena.jdbc.configuration.ConnectionParameter;
import com.amazon.athena.jdbc.configuration.ConnectionParameters;
import com.amazon.athena.jdbc.support.AuthenticationException;
import com.amazon.athena.jdbc.support.EndpointHelper;
import com.amazon.athena.jdbc.support.ProxyHelper;
import com.amazon.athena.logging.AthenaLogger;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.time.Clock;
import java.time.Instant;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.codec.binary.Base64;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.apache.ProxyConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lakeformation.LakeFormationClient;
import software.amazon.awssdk.services.lakeformation.LakeFormationClientBuilder;
import software.amazon.awssdk.services.lakeformation.model.AssumeDecoratedRoleWithSamlRequest;
import software.amazon.awssdk.services.lakeformation.model.AssumeDecoratedRoleWithSamlResponse;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlResponse;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.utils.Pair;

/**
 * Base class for credentials providers that exchange a SAML assertion for temporary AWS
 * credentials.
 *
 * <p>Subclasses supply the Base64-encoded SAML assertion through {@link #getSamlAssertion()}. The
 * role and SAML provider ARNs are read from the assertion's
 * {@code https://aws.amazon.com/SAML/Attributes/Role} attribute. The assertion is then exchanged
 * with AWS STS ({@code AssumeRoleWithSAML}) or, when Lake Formation is enabled, with AWS Lake
 * Formation ({@code AssumeDecoratedRoleWithSAML}). The resulting credentials are cached and only
 * refreshed once they are within {@value #EXPIRATION_THRESHOLD_SECS} seconds of expiring.
 *
 * <p>The credential cache is not synchronized; concurrent callers of {@link #resolveCredentials()}
 * may each trigger a refresh.
 */
abstract class SamlCredentialsProvider extends IdpCredentialsProvider implements AwsCredentialsProvider {
    private static final AthenaLogger logger = AthenaLogger.of(SamlCredentialsProvider.class);
    private static final String ROLE_PATTERN = "arn:aws[-a-z]*:iam::\\d*:role/\\S+";
    private static final String SAML_PROVIDER_PATTERN = "arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+";
    /** Compiled once instead of recompiling on every {@link String#matches} call. */
    private static final Pattern ROLE_ARN = Pattern.compile(ROLE_PATTERN);
    private static final Pattern SAML_PROVIDER_ARN = Pattern.compile(SAML_PROVIDER_PATTERN);
    /** Selects the text of every AttributeValue of the AWS Role attribute, ignoring namespace prefixes. */
    private static final String ROLE_ATTRIBUTE_XPATH =
            "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()";
    /** Cached credentials are refreshed once they are this many seconds (or fewer) from expiring. */
    private static final int EXPIRATION_THRESHOLD_SECS = 180;
    /** Named character references understood by {@link #decodeHtmlCharacterReferences(String)}. */
    private static final String[][] NAMED_CHARACTER_REFERENCES = {
        {"&amp;", "&"},
        {"&apos;", "'"},
        {"&quot;", "\""},
        {"&lt;", "<"},
        {"&gt;", ">"},
    };
    private final AssumeRoleWithSamlRequest.Builder assumeRoleWithSamlRequestFactory;
    private final AssumeDecoratedRoleWithSamlRequest.Builder assumeDecoratedRoleWithSamlRequestFactory;
    private final StsClientBuilder stsClientFactory;
    private final LakeFormationClientBuilder lakeFormationClientFactory;
    private final DocumentBuilderFactory documentBuilderFactory;
    private final String preferredRole;
    private final Integer roleSessionDuration;
    private final Region region;
    private final boolean lakeFormationEnabled;
    private final Map<ConnectionParameter<?>, String> parameters;
    private Credentials stsCredentials;
    private AssumeDecoratedRoleWithSamlResponse lakeFormationCredentials;
    protected final Clock clock;

    /**
     * Creates the provider. Any {@code null} factory or clock is replaced with the SDK/JDK default.
     *
     * @param assumeRoleWithSamlRequestFactory builder for STS requests, or {@code null} for the default
     * @param assumeDecoratedRoleWithSamlRequestFactory builder for Lake Formation requests, or {@code null}
     * @param stsClientFactory builder for the STS client, or {@code null} for {@link StsClient#builder()}
     * @param lakeFormationClientFactory builder for the Lake Formation client, or {@code null}
     * @param documentBuilderFactory XML parser factory used to read the assertion, or {@code null}
     * @param clock clock used for expiration checks, or {@code null} for the system clock
     * @param preferredRole role ARN to assume; when {@code null} the first role in the assertion is used
     * @param roleSessionDuration requested session duration in seconds, passed through to the request
     * @param region region the STS / Lake Formation client is built for
     * @param lakeFormationEnabled whether to obtain credentials from Lake Formation instead of STS
     * @param parameters connection parameters; read for endpoint overrides and proxy settings
     */
    protected SamlCredentialsProvider(AssumeRoleWithSamlRequest.Builder assumeRoleWithSamlRequestFactory, AssumeDecoratedRoleWithSamlRequest.Builder assumeDecoratedRoleWithSamlRequestFactory, StsClientBuilder stsClientFactory, LakeFormationClientBuilder lakeFormationClientFactory, DocumentBuilderFactory documentBuilderFactory, Clock clock, String preferredRole, Integer roleSessionDuration, Region region, boolean lakeFormationEnabled, Map<ConnectionParameter<?>, String> parameters) {
        this.assumeRoleWithSamlRequestFactory = assumeRoleWithSamlRequestFactory == null ? AssumeRoleWithSamlRequest.builder() : assumeRoleWithSamlRequestFactory;
        this.assumeDecoratedRoleWithSamlRequestFactory = assumeDecoratedRoleWithSamlRequestFactory == null ? AssumeDecoratedRoleWithSamlRequest.builder() : assumeDecoratedRoleWithSamlRequestFactory;
        this.stsClientFactory = stsClientFactory == null ? StsClient.builder() : stsClientFactory;
        this.lakeFormationClientFactory = lakeFormationClientFactory == null ? LakeFormationClient.builder() : lakeFormationClientFactory;
        this.documentBuilderFactory = documentBuilderFactory == null ? DocumentBuilderFactory.newInstance() : documentBuilderFactory;
        this.clock = clock == null ? Clock.systemDefaultZone() : clock;
        this.preferredRole = preferredRole;
        this.roleSessionDuration = roleSessionDuration;
        this.region = region;
        this.lakeFormationEnabled = lakeFormationEnabled;
        this.parameters = parameters;
    }

    /**
     * Returns the Base64-encoded SAML assertion to exchange for AWS credentials.
     *
     * @return the encoded assertion
     */
    protected abstract String getSamlAssertion();

    /**
     * Returns cached session credentials, first obtaining new ones from Lake Formation or STS when
     * nothing is cached yet or the cached credentials are about to expire.
     */
    @Override
    public AwsCredentials resolveCredentials() {
        if (this.lakeFormationEnabled) {
            if (this.lakeFormationCredentials == null || this.isExpiringSoon(this.lakeFormationCredentials.expiration())) {
                this.lakeFormationCredentials = this.obtainCredentialsFromLakeFormation(this.getSamlAssertion());
            }
            return AwsSessionCredentials.create(this.lakeFormationCredentials.accessKeyId(), this.lakeFormationCredentials.secretAccessKey(), this.lakeFormationCredentials.sessionToken());
        }
        if (this.stsCredentials == null || this.isExpiringSoon(this.stsCredentials.expiration())) {
            this.stsCredentials = this.obtainCredentialsFromSts(this.getSamlAssertion());
        }
        return AwsSessionCredentials.create(this.stsCredentials.accessKeyId(), this.stsCredentials.secretAccessKey(), this.stsCredentials.sessionToken());
    }

    /** Returns whether {@code expiration} falls within the refresh threshold from now. */
    private boolean isExpiringSoon(Instant expiration) {
        return expiration.isBefore(this.clock.instant().plusSeconds(EXPIRATION_THRESHOLD_SECS));
    }

    /**
     * Exchanges the assertion for credentials with STS {@code AssumeRoleWithSAML}.
     *
     * <p>The client is built anonymously (the assertion is the only proof of identity) and is closed
     * after the call, so its HTTP connection pool is not leaked on every refresh.
     */
    private Credentials obtainCredentialsFromSts(String encodedSamlAssertion) {
        Pair<String, String> roleAndPrincipal = this.extractRoleAndPrincipal(encodedSamlAssertion);
        AssumeRoleWithSamlRequest request = this.assumeRoleWithSamlRequestFactory
                .samlAssertion(encodedSamlAssertion)
                .roleArn(roleAndPrincipal.left())
                .principalArn(roleAndPrincipal.right())
                .durationSeconds(this.roleSessionDuration)
                .build();
        this.getStsEndpoint().ifPresent(this.stsClientFactory::endpointOverride);
        ProxyHelper.getSyncProxyConfiguration(this.parameters).ifPresent(proxyConfiguration ->
                this.stsClientFactory.httpClientBuilder(this.getHttpClientBuilder((ProxyConfiguration) proxyConfiguration)));
        logger.debug("Obtaining credentials from STS", new Object[0]);
        // The assertion is masked: it can be exchanged for credentials by anyone who reads the log.
        logger.trace("Sending AssumeRoleWithSaml request: {}", describeSamlRequest("AssumeRoleWithSamlRequest", request.roleArn(), request.principalArn(), request.durationSeconds()));
        try (StsClient stsClient = this.stsClientFactory
                .region(this.region)
                .credentialsProvider(AnonymousCredentialsProvider.create())
                .build()) {
            AssumeRoleWithSamlResponse response = stsClient.assumeRoleWithSAML(request);
            logger.info("Obtained credentials from STS", new Object[0]);
            return response.credentials();
        }
    }

    /**
     * Exchanges the assertion for credentials with Lake Formation {@code AssumeDecoratedRoleWithSAML}.
     *
     * <p>The client is built anonymously and closed after the call.
     */
    private AssumeDecoratedRoleWithSamlResponse obtainCredentialsFromLakeFormation(String encodedSamlAssertion) {
        Pair<String, String> roleAndPrincipal = this.extractRoleAndPrincipal(encodedSamlAssertion);
        AssumeDecoratedRoleWithSamlRequest request = this.assumeDecoratedRoleWithSamlRequestFactory
                .samlAssertion(encodedSamlAssertion)
                .roleArn(roleAndPrincipal.left())
                .principalArn(roleAndPrincipal.right())
                .durationSeconds(this.roleSessionDuration)
                .build();
        this.getLakeFormationEndpoint().ifPresent(this.lakeFormationClientFactory::endpointOverride);
        ProxyHelper.getSyncProxyConfiguration(this.parameters).ifPresent(proxyConfiguration ->
                this.lakeFormationClientFactory.httpClientBuilder(this.getHttpClientBuilder((ProxyConfiguration) proxyConfiguration)));
        logger.debug("Obtaining credentials from Lake Formation", new Object[0]);
        logger.trace("Sending AssumeDecoratedRoleWithSaml request: {}", describeSamlRequest("AssumeDecoratedRoleWithSamlRequest", request.roleArn(), request.principalArn(), request.durationSeconds()));
        try (LakeFormationClient lakeFormationClient = this.lakeFormationClientFactory
                .region(this.region)
                .credentialsProvider(AnonymousCredentialsProvider.create())
                .build()) {
            AssumeDecoratedRoleWithSamlResponse response = lakeFormationClient.assumeDecoratedRoleWithSAML(request);
            logger.info("Obtained credentials from Lake Formation", new Object[0]);
            return response;
        }
    }

    /** Renders a SAML request for logging with the assertion itself replaced by a mask. */
    private static String describeSamlRequest(String requestName, String roleArn, String principalArn, Integer durationSeconds) {
        return String.format("%s(SAMLAssertion=*******, RoleArn=%s, PrincipalArn=%s, DurationSeconds=%s)", requestName, roleArn, principalArn, durationSeconds);
    }

    /** Returns the STS endpoint override from the connection parameters, if configured. */
    private Optional<URI> getStsEndpoint() {
        Optional<String> stsEndpoint = ConnectionParameters.STS_ENDPOINT_PARAMETER.findValue(this.parameters);
        return stsEndpoint.map(endpoint -> EndpointHelper.constructEndpointUri(endpoint, "STS"));
    }

    /** Returns the Lake Formation endpoint override from the connection parameters, if configured. */
    private Optional<URI> getLakeFormationEndpoint() {
        Optional<String> lakeFormationEndpoint = ConnectionParameters.LAKE_FORMATION_ENDPOINT_PARAMETER.findValue(this.parameters);
        return lakeFormationEndpoint.map(endpoint -> EndpointHelper.constructEndpointUri(endpoint, "Lake Formation"));
    }

    /** Creates an Apache HTTP client builder that routes through the given proxy. */
    private ApacheHttpClient.Builder getHttpClientBuilder(ProxyConfiguration proxyConfiguration) {
        return ApacheHttpClient.builder().proxyConfiguration(proxyConfiguration);
    }

    /**
     * Parses the encoded assertion and selects the role ARN (left) and SAML provider ARN (right)
     * to use for the exchange.
     */
    private Pair<String, String> extractRoleAndPrincipal(String samlAssertion) {
        NodeList roleAttributeValues = this.findRoleSamlAttributes(samlAssertion);
        Map<String, String> roles = findIamRolesAndPrincipals(roleAttributeValues);
        return this.findPreferredRoleAndPrincipal(roles);
    }

    /**
     * Parses the decoded assertion into a DOM with DOCTYPE declarations, XInclude and external
     * entities disabled, because the assertion comes from outside the driver (XXE hardening).
     */
    private Document parseIntoDom(byte[] xmlSamlAssertion) {
        try {
            this.documentBuilderFactory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
            this.documentBuilderFactory.setXIncludeAware(false);
            this.documentBuilderFactory.setExpandEntityReferences(false);
            this.documentBuilderFactory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
            this.documentBuilderFactory.setFeature("http://xml.org/sax/features/external-general-entities", false);
            DocumentBuilder db = this.documentBuilderFactory.newDocumentBuilder();
            return db.parse(new ByteArrayInputStream(xmlSamlAssertion));
        }
        catch (IOException | ParserConfigurationException | SAXException e) {
            throw new AuthenticationException("An error occurred while parsing the SAML assertion into DOM", e);
        }
    }

    /**
     * Returns the text nodes of every value of the AWS Role attribute in the assertion.
     *
     * @throws AuthenticationException if the assertion cannot be parsed or has no role attribute
     */
    private NodeList findRoleSamlAttributes(String samlAssertion) {
        Document doc = this.parseIntoDom(Base64.decodeBase64(samlAssertion));
        XPath xPath = XPathFactory.newInstance().newXPath();
        try {
            XPathExpression compiledExpression = xPath.compile(ROLE_ATTRIBUTE_XPATH);
            NodeList roleSamlAttributes = (NodeList) compiledExpression.evaluate(doc, XPathConstants.NODESET);
            if (roleSamlAttributes.getLength() == 0) {
                // The assertion itself is deliberately not logged: it is exchanged for credentials
                // without any other authentication, so it must be treated as a secret.
                logger.error("No role attribute found in the SAML assertion", new Object[0]);
                throw new AuthenticationException("No role attribute found in the SAML assertion");
            }
            return roleSamlAttributes;
        }
        catch (XPathExpressionException e) {
            throw new AuthenticationException("An error occurred while attempting to find the SAML role attribute", e);
        }
    }

    /**
     * Builds a map from role ARN to SAML provider ARN out of "role,provider" attribute values.
     * Values lacking either ARN are skipped. Each comma-separated part is trimmed so that
     * whitespace around the separator or the value does not hide a valid pair. Insertion order
     * follows the assertion, so the "first" role is well defined.
     */
    private static Map<String, String> findIamRolesAndPrincipals(NodeList samlRoleNodes) {
        Map<String, String> roles = new LinkedHashMap<>();
        if (samlRoleNodes == null) {
            return roles;
        }
        for (int i = 0; i < samlRoleNodes.getLength(); ++i) {
            Node node = samlRoleNodes.item(i);
            String roleAndProviderPair = node.getNodeValue();
            if (roleAndProviderPair == null) {
                continue;
            }
            String[] arns = roleAndProviderPair.split(",");
            Optional<String> role = Stream.of(arns).map(String::trim).filter(arn -> ROLE_ARN.matcher(arn).matches()).findFirst();
            Optional<String> principal = Stream.of(arns).map(String::trim).filter(arn -> SAML_PROVIDER_ARN.matcher(arn).matches()).findFirst();
            if (role.isPresent() && principal.isPresent()) {
                roles.put(role.get(), principal.get());
            }
        }
        return roles;
    }

    /**
     * Picks the configured preferred role, or the first role found when none is configured.
     *
     * @throws AuthenticationException if the preferred role is absent or no role was found
     */
    private Pair<String, String> findPreferredRoleAndPrincipal(Map<String, String> roles) {
        if (this.preferredRole != null) {
            String principal = roles.get(this.preferredRole);
            if (principal == null) {
                // Log the role ARNs that were present rather than the (secret) assertion itself.
                logger.error("Preferred role {} not found in SAML assertion; roles present: {}", this.preferredRole, roles.keySet());
                throw new AuthenticationException("Preferred role not found in SAML assertion");
            }
            return Pair.of(this.preferredRole, principal);
        }
        return roles.entrySet().stream()
                .findFirst()
                .map(entry -> Pair.of(entry.getKey(), entry.getValue()))
                .orElseThrow(() -> new AuthenticationException("None of the role attributes in the SAML assertion contain a role"));
    }

    /**
     * Decodes character references in {@code html}. It handles the named references
     * {@code &amp;}, {@code &apos;}, {@code &quot;}, {@code &lt;} and {@code &gt;}, and well-formed
     * numeric references in decimal ({@code &#43;}) or hexadecimal ({@code &#x2b;} / {@code &#X2B;})
     * form. Unknown or malformed references are copied through unchanged.
     *
     * @param html text that may contain character references; must not be {@code null}
     * @return the decoded text
     */
    protected String decodeHtmlCharacterReferences(String html) {
        int length = html.length();
        StringBuilder decoded = new StringBuilder(length);
        int i = 0;
        while (i < length) {
            char c = html.charAt(i);
            int consumed = c == '&' ? decodeCharacterReference(html, i, decoded) : 0;
            if (consumed > 0) {
                i += consumed;
            } else {
                decoded.append(c);
                ++i;
            }
        }
        return decoded.toString();
    }

    /**
     * Decodes the character reference starting at {@code start} (which holds '&') into {@code out}.
     *
     * @return the number of characters consumed, or 0 if no recognized reference starts there
     */
    private static int decodeCharacterReference(String html, int start, StringBuilder out) {
        for (String[] reference : NAMED_CHARACTER_REFERENCES) {
            if (html.startsWith(reference[0], start)) {
                out.append(reference[1]);
                return reference[0].length();
            }
        }
        if (!html.startsWith("&#", start)) {
            return 0;
        }
        int pos = start + 2;
        int radix = 10;
        // Digit limits keep the accumulated value well inside int range (max 0xFFFFFF / 9999999).
        int maxDigits = 7;
        if (pos < html.length() && (html.charAt(pos) == 'x' || html.charAt(pos) == 'X')) {
            radix = 16;
            maxDigits = 6;
            ++pos;
        }
        int digitsStart = pos;
        int codePoint = 0;
        while (pos < html.length() && pos - digitsStart < maxDigits) {
            char ch = html.charAt(pos);
            // Only ASCII digits are valid in a character reference; Character.digit alone
            // would also accept other Unicode digit characters.
            int digit = ch < 128 ? Character.digit(ch, radix) : -1;
            if (digit < 0) {
                break;
            }
            codePoint = codePoint * radix + digit;
            ++pos;
        }
        if (pos == digitsStart || pos >= html.length() || html.charAt(pos) != ';') {
            return 0;
        }
        if (codePoint == 0 || codePoint > Character.MAX_CODE_POINT
                || (codePoint >= Character.MIN_SURROGATE && codePoint <= Character.MAX_SURROGATE)) {
            return 0;
        }
        out.appendCodePoint(codePoint);
        return pos + 1 - start;
    }
}

