package eu.nets.sis.eident.demoapp.service;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import org.apache.commons.codec.binary.Base64;
import org.apache.log4j.Logger;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.RSADecrypter;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.ErrorObject;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.CommonContentTypes;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.util.JSONObjectUtils;
import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;

import net.minidev.json.JSONArray;
import net.minidev.json.JSONObject;

/**
 * Service that exchanges an OIDC authorization response for an ID Token, verifies the token
 * and builds the map of claims extracted from it.
 */
public class ClaimsService {

	/** Date pattern used by {@code getClaims} to compute the expected state value. */
	public static final String DATE_FORMAT = "dd/MM/yyyy";
	/** Key of the token endpoint URL in the discovery document. */
	private static final String TOKEN_ENDPOINT = "token_endpoint";
	/** Key of the JWKS URL in the discovery document. */
	private static final String JWKS_URI = "jwks_uri";
	/** Key of the ID Token in the token endpoint response. */
	private static final String ID_TOKEN = "id_token";
	/** Key of the key array in the JWKS document. */
	private static final String KEYS = "keys";
	/** Key type ("kty") value accepted when building the signature verification key. */
	public static final String RSA_ALGORITHM = "RSA";
	/** Label for the nonce, used when reporting a nonce mismatch. */
	public static final String NONCE = "nonce";
	/** Class logger; final because it is never reassigned. */
	private static final Logger logger = Logger.getLogger(ClaimsService.class);

	/** Description of the last authentication error response. */
	private String errorDesc; //used in case of timeout and user cancel events

	/**
	 * Returns the error description recorded by {@code getClaims} when the authorization
	 * server answered with an error response.
	 *
	 * @return the recorded error description, or {@code null} if none has been recorded
	 */
	public String getErrorDesc() {
		return this.errorDesc;
	}

	/**
	 * Retrieves the ID Token from the OIDC token endpoint, verifies its signature and
	 * extracts the claims of interest.
	 * <p>
	 * Steps: parse the authorization response, compare the returned state, read the discovery
	 * document, exchange the authorization code for an ID Token, decrypt the token if it is
	 * encrypted, verify its RSA signature against the key published at the JWKS URI, compare
	 * the nonce and finally build the claims map.
	 *
	 * @param responseURL URL with query string returned by authorization server
	 * @param redirectUri redirect URL matching the one sent for authorization request
	 * @param discoveryUri To discover OIDC supported claims and JWKS URI
	 * @param clientID Merchant identifier provided by Nets
	 * @param secretCode Authentication password to access token end point URL, provided by Nets
	 * @param storedNonce nonce sent to authorization server, used to mitigate replay attacks
	 * @param truststorePath SSL truststore to access discovery URL and Token end point URL
	 * @param truststorePassword SSL truststore password
	 * @param decryptionKeystore PKCS12 keystore holding the private key used to decrypt an encrypted ID Token
	 * @param decryptionKeystorePassword password of the decryption keystore
	 * @param decryptionKeyAlias alias of the private key inside the decryption keystore
	 * @return Claims in key value format, or {@code null} when the authorization server returned an
	 *         error response (the description is then available through {@link #getErrorDesc()})
	 * @throws Exception is thrown in case of any error
	 */
	public Map<String, String> getClaims(String responseURL, String redirectUri, String discoveryUri, String clientID,
			String secretCode, String storedNonce, String truststorePath, String truststorePassword,
			String decryptionKeystore, String decryptionKeystorePassword, String decryptionKeyAlias) throws Exception {

		Map<String, String> claimsMap = new HashMap<>();
		try {
			AuthenticationResponse authResp = AuthenticationResponseParser.parse(new URI(responseURL));

			if (authResp instanceof AuthenticationErrorResponse) {
				ErrorObject error = ((AuthenticationErrorResponse) authResp).getErrorObject();
				logger.error("Error code: " + error.getCode());
				logger.error("Error description: " + error.getDescription());
				this.errorDesc = error.getDescription();
				return null;
			}

			AuthenticationSuccessResponse successResponse = (AuthenticationSuccessResponse) authResp;

			/* This is a temporary solution. Merchant should handle this properly. State is used to avoid Cross-Site Request Forgery (CSRF, XSRF).
			 * The state returned by authorization server should match the one that is sent to authorization server.
			 * NOTE(review): a response without any state skips this check, and the expected value is
			 * today's date, so a flow that spans midnight is rejected.
			 */
			String requestedState = new SimpleDateFormat(DATE_FORMAT).format(Calendar.getInstance().getTime());
			String state = null;
			if (null != successResponse.getState()) {
				state = successResponse.getState().getValue();
			}

			if (null != state && !state.equals(requestedState)) {
				logger.error("Received state does not match with Requested state");
				throw new Exception("state mismatch");
			}

			AuthorizationCode authCode = successResponse.getAuthorizationCode();

			// Get the JSON Object from Discovery URL
			HTTPResponse discoveryHTTPResp = getHTTPRequest(discoveryUri, truststorePath, truststorePassword).send();

			JSONObject discoveryJsonObject = discoveryHTTPResp.getContentAsJSONObject();
			String tokenEndPoint = JSONObjectUtils.getString(discoveryJsonObject, TOKEN_ENDPOINT);
			String jwks_uri = JSONObjectUtils.getString(discoveryJsonObject, JWKS_URI);
			logger.debug("tokenEndPoint=" + tokenEndPoint + " , jwks_uri=" + jwks_uri);

			ClientSecretBasic clientSecretBasic = new ClientSecretBasic(new ClientID(clientID), new Secret(secretCode));
			TokenRequest tokenReq = new TokenRequest(new URI(tokenEndPoint), clientSecretBasic, new AuthorizationCodeGrant(authCode, new URI(redirectUri)));

			// No truststore is passed explicitly here; the request is sent as built by toHTTPRequest().
			HTTPResponse tokenHTTPResp = tokenReq.toHTTPRequest().send();
			int tokenStatus = tokenHTTPResp.getStatusCode();
			if (400 == tokenStatus) {
				logger.error("Invalid response from token endpoint [statuscode=400]. Probably invalid secret code.  Please contact IN Groupe support.");
				throw new Exception("Invalid response from token endpoint [statuscode=400].");
			}
			// Any other non-success status cannot carry a usable ID Token; fail with a clear message
			// instead of a JSON parsing error further down.
			if (tokenStatus < 200 || tokenStatus >= 300) {
				logger.error("Unexpected response from token endpoint [statuscode=" + tokenStatus + "].");
				throw new Exception("Unexpected response from token endpoint [statuscode=" + tokenStatus + "].");
			}

			// Get JSON object from Token response
			JSONObject tokenJsonObject = tokenHTTPResp.getContentAsJSONObject();
			String idtoken = JSONObjectUtils.getString(tokenJsonObject, ID_TOKEN);

			SignedJWT signedJWT;
			// Parse and check response
			JWT jwt = JWTParser.parse(idtoken);
			if (jwt instanceof EncryptedJWT) {
				logger.info("Encrypted IDToken found");
				signedJWT = decryptJWT(idtoken, decryptionKeystore, decryptionKeystorePassword, decryptionKeyAlias);
			} else {
				logger.info("No encryption found in IDToken");
				signedJWT = SignedJWT.parse(idtoken);
			}
			if (null == signedJWT) {
				// The decrypted payload was not a signed JWT; there is nothing to verify.
				logger.error("Decrypted IDToken does not contain a signed JWT");
				throw new JOSEException("IDToken is not a signed JWT");
			}

			// Get JSON response from jwks_uri
			HTTPResponse jwksResp = getHTTPRequest(jwks_uri, truststorePath, truststorePassword).send();
			JSONObject jwksJsonObject = jwksResp.getContentAsJSONObject();
			JSONArray keys = JSONObjectUtils.getJSONArray(jwksJsonObject, KEYS);

			String tokenKeyId = signedJWT.getHeader().getKeyID();
			JSONObject jsonKey = findJwkByKeyId(keys, tokenKeyId); //match kid from header
			if (null == jsonKey) {
				logger.error("No public key is found for IDToken signature verification");
				throw new Exception("No public key is found for IDToken signature verification");
			}

			logger.info("Verify signature for kid=" + tokenKeyId);
			PublicKey publicKey = buildPublicKey(jsonKey);
			if (!(publicKey instanceof RSAPublicKey)) {
				// A matching key of an unsupported type must not let the token through unverified.
				logger.error("Unsupported key type for IDToken signature verification, kid=" + tokenKeyId);
				throw new JOSEException("Unsupported key type for IDToken signature verification");
			}

			// Verify Signature
			JWSVerifier verifier = new RSASSAVerifier((RSAPublicKey) publicKey);
			if (!signedJWT.verify(verifier)) {
				logger.error("Couldn't verify signature");
				throw new JOSEException("Signature mismatch");
			}
			logger.info("Signature is verified");

			JWTClaimsSet jwtClaimsSet = signedJWT.getJWTClaimsSet();
			logger.debug("claims: " + jwtClaimsSet);

			/* Merchant should handle nonce appropriately. nonce is used to avoid replay attacks.
			 * Validate the nonce received from authorization server against that one that is sent to it.
			 * NOTE(review): an ID Token without a nonce claim skips this check, and the iss, aud and exp
			 * claims are not validated here — confirm whether the caller covers that.
			 */
			String nonce = jwtClaimsSet.getStringClaim("nonce");
			if (null != nonce && !nonce.equals(storedNonce)) {
				logger.error("Possible replay attack detected! The comparison of the nonce in the returned ID Token to the session "
								+ NONCE + " failed. Expected [" + storedNonce + "] got [" + nonce + "]");
				throw new Exception("nonce mismatch");
			}

			// prepare claims map
			claimsMap.put("subjectDN", jwtClaimsSet.getStringClaim("dn"));
			claimsMap.put("ssn", getSSN(jwtClaimsSet));
			claimsMap.put("ssn_issuing_country", getSsnIssuingCountry(jwtClaimsSet));
			claimsMap.put("pid", getPID(jwtClaimsSet));
			claimsMap.put("providerID", extractProviderId(jwtClaimsSet));
			claimsMap.put("certPolicyOID", jwtClaimsSet.getStringClaim("certpolicyoid"));
			claimsMap.put("dob", jwtClaimsSet.getStringClaim("birthdate"));
			claimsMap.put("authmethod", "OAuth 2.0");
			claimsMap.put("rawClaims", prettyPrint(jwtClaimsSet));
			String cert = jwtClaimsSet.getStringClaim("certificate");
			if (null != cert && !cert.isEmpty()) {
				claimsMap.put("cert", cert);
			} else {
				claimsMap.put("cert", "No. Cert. Here. :-(");
			}
		} catch (Exception e) {
			logger.error("Exception while retrieving claims", e);
			throw e;
		}

		return claimsMap;
	}

	/**
	 * Finds the JWK whose "kid" equals the given key id.
	 *
	 * @param keys JWKS key array
	 * @param keyId key id taken from the ID Token header
	 * @return the matching JWK, or {@code null} when no entry matches
	 */
	private JSONObject findJwkByKeyId(JSONArray keys, String keyId) {
		for (int i = 0; i < keys.size(); i++) {
			JSONObject candidate = (JSONObject) keys.get(i);
			Object kid = candidate.get("kid");
			if (null != kid && kid.toString().equals(keyId)) {
				return candidate;
			}
		}
		return null;
	}

	/**
	 * Reads the provider id from the "amr" claim, which may be a single string or an array.
	 *
	 * @param jwtClaimsSet JWT claims set
	 * @return the string value, or the first array element; {@code null} when the claim is absent or empty
	 * @throws ParseException if the claim has an unexpected type
	 */
	private String extractProviderId(JWTClaimsSet jwtClaimsSet) throws ParseException {
		Object amr = jwtClaimsSet.getClaim("amr");
		if (null == amr) {
			return null;
		}
		if (amr instanceof String) {
			return jwtClaimsSet.getStringClaim("amr");
		}
		String[] methods = jwtClaimsSet.getStringArrayClaim("amr");
		if (null == methods || methods.length == 0) {
			return null;
		}
		return methods[0];
	}

	/**
	 * Decrypts the encrypted id-token with the merchant's private key.
	 *
	 * @param encryptedRequest encrypted id-token string
	 * @param decryptionKeystore merchant private key (p12) for decryption
	 * @param decryptionKeystorePassword merchant private key (p12) password
	 * @param decryptionKeyAlias alias of the private key inside the keystore
	 * @return the decrypted, signed id-token; never {@code null}
	 * @throws Exception if the private key cannot be found, decryption fails, or the decrypted
	 *         payload is not a signed JWT
	 */
	public SignedJWT decryptJWT(String encryptedRequest, String decryptionKeystore, String decryptionKeystorePassword, String decryptionKeyAlias) throws Exception {
		logger.info("Decrypting IDToken");
		PrivateKey privateKey = getPrivateKey(decryptionKeystore, decryptionKeystorePassword, decryptionKeyAlias);
		if (null == privateKey) {
			// Fail with a clear message instead of handing a null key to the decrypter.
			throw new JOSEException("No private key found in decryption keystore for alias " + decryptionKeyAlias);
		}
		JWEObject jweData = JWEObject.parse(encryptedRequest);
		jweData.decrypt(new RSADecrypter(privateKey));
		SignedJWT signedJWT = jweData.getPayload().toSignedJWT();
		if (null == signedJWT) {
			// toSignedJWT() yields null when the payload is not a signed JWT.
			throw new JOSEException("Decrypted IDToken payload is not a signed JWT");
		}
		logger.info("Decryption is done");

		return signedJWT;
	}

	/**
	 * Loads the PKCS12 keystore and returns the private key stored under the given alias.
	 * <p>
	 * The keystore is looked up, in order: at the given path, under {@code target/}, and on the
	 * classpath. The stream is closed once the keystore has been loaded.
	 *
	 * @param decryptionKeystore keystore path or classpath resource name
	 * @param decryptionKeystorePassword password of the keystore and of the key
	 * @param decryptionKeyAlias alias of the private key
	 * @return the private key, or {@code null} if the keystore has no key under that alias
	 * @throws Exception if an argument is null, the keystore cannot be found or cannot be loaded
	 */
	private PrivateKey getPrivateKey(String decryptionKeystore, String decryptionKeystorePassword, String decryptionKeyAlias) throws Exception {
		if (null == decryptionKeystore || null == decryptionKeystorePassword) {
			throw new Exception("decryption keystore path or password is null");
		}
		char[] passphrase = decryptionKeystorePassword.toCharArray();

		KeyStore keystore = KeyStore.getInstance("PKCS12");

		File targetDir = new File(".");
		logger.info("Target Directory Path=" + targetDir.getAbsolutePath());

		// try-with-resources: the original never closed the keystore stream.
		try (InputStream keystoreStream = openDecryptionKeystoreStream(decryptionKeystore)) {
			keystore.load(keystoreStream, passphrase);
		}
		logger.info("keystore loaded");
		PrivateKey privateKey = (PrivateKey) keystore.getKey(decryptionKeyAlias, passphrase);
		logger.info("PrivateKey with alias = " +decryptionKeyAlias + (privateKey == null ? " Not found" : " Found"));
		return privateKey;
	}

	/**
	 * Opens the decryption keystore from the file system or, failing that, from the classpath.
	 *
	 * @param decryptionKeystore keystore path or classpath resource name
	 * @return an open stream on the keystore; the caller must close it
	 * @throws IOException if the keystore cannot be found or opened
	 */
	private InputStream openDecryptionKeystoreStream(String decryptionKeystore) throws IOException {
		File direct = new File(decryptionKeystore);
		if (direct.exists()) {
			logger.info("Keystore File exists in the current folder or at your provided path!");
			return Files.newInputStream(direct.toPath());
		}

		File underTarget = new File("target/" + decryptionKeystore);
		if (underTarget.exists()) {
			logger.info("Keystore File exists inside target folder");
			return Files.newInputStream(underTarget.toPath());
		}

		logger.info("File not found elsewhere, hence fetching from classpath");
		ClassLoader classLoader = getClass().getClassLoader();
		// ClassLoader resource names are not expected to start with '/', so try the plain name first.
		String resourceName = decryptionKeystore.startsWith("/") ? decryptionKeystore.substring(1) : decryptionKeystore;
		InputStream resourceStream = classLoader.getResourceAsStream(resourceName);
		if (null == resourceStream) {
			// Keep the original lookup as a fallback for class loaders that accept a leading slash.
			resourceStream = classLoader.getResourceAsStream("/" + decryptionKeystore);
		}
		if (null == resourceStream) {
			// Loading a null stream would silently create an empty keystore.
			throw new IOException("Decryption keystore not found: " + decryptionKeystore);
		}
		return resourceStream;
	}

	/**
	 * Converts a JWK into a public key when its key type is supported.
	 *
	 * @param keyJson JWK as JSON object; may be {@code null}
	 * @return the RSA public key, or {@code null} when {@code keyJson} is {@code null} or its
	 *         "kty" is not {@link #RSA_ALGORITHM}
	 * @throws com.nimbusds.oauth2.sdk.ParseException if a required JWK member cannot be read
	 * @throws InvalidKeySpecException if the key material is invalid
	 * @throws NoSuchAlgorithmException if the key algorithm is not available
	 */
	private PublicKey buildPublicKey(JSONObject keyJson) throws com.nimbusds.oauth2.sdk.ParseException, InvalidKeySpecException, NoSuchAlgorithmException {
		if (keyJson != null) {
			final String keyType = JSONObjectUtils.getString(keyJson, "kty");
			if (RSA_ALGORITHM.equals(keyType)) {
				return buildRSAPublicKey(keyType, keyJson);
			}
		}
		// Either no JWK was given or its key type is not handled here.
		return null;
	}

	/**
	 * Creates an RSA public key from the "n" (modulus) and "e" (public exponent) members of a JWK.
	 *
	 * @param kty algorithm name handed to {@link KeyFactory#getInstance(String)}
	 * @param keyJson JWK as JSON object holding the Base64-encoded "n" and "e" members
	 * @return the generated RSA public key
	 * @throws InvalidKeySpecException if the key specification is rejected by the key factory
	 * @throws NoSuchAlgorithmException if no key factory exists for {@code kty}
	 * @throws com.nimbusds.oauth2.sdk.ParseException if "n" or "e" cannot be read as a string
	 */
	private PublicKey buildRSAPublicKey(String kty, JSONObject keyJson) throws InvalidKeySpecException, NoSuchAlgorithmException, com.nimbusds.oauth2.sdk.ParseException {
		final String encodedModulus = JSONObjectUtils.getString(keyJson, "n");
		final String encodedExponent = JSONObjectUtils.getString(keyJson, "e");

		final Base64 decoder = new Base64();
		final byte[] modulusBytes = decoder.decode(encodedModulus);
		final byte[] exponentBytes = decoder.decode(encodedExponent);

		// Signum 1: the decoded bytes are interpreted as unsigned magnitudes.
		final RSAPublicKeySpec keySpec = new RSAPublicKeySpec(new BigInteger(1, modulusBytes), new BigInteger(1, exponentBytes));

		final KeyFactory factory = KeyFactory.getInstance(kty);
		return factory.generatePublic(keySpec);
	}

	/**
	 * Re-formats the JSON text produced by {@code obj.toString()} with the default pretty printer.
	 *
	 * @param obj object whose {@code toString()} yields JSON
	 * @return the indented JSON text
	 * @throws IOException if the text cannot be parsed or written as JSON
	 */
	private String prettyPrint(Object obj) throws IOException {
		final ObjectMapper jsonMapper = new ObjectMapper();
		final String rawJson = obj.toString();
		// Parse into a generic tree first so the writer can re-indent it.
		final Object parsedTree = jsonMapper.readValue(rawJson, Object.class);
		return jsonMapper.writerWithDefaultPrettyPrinter().writeValueAsString(parsedTree);
	}

	/**
	 * Looks up the SSN under each known claim name, in order, and returns the first non-empty value.
	 *
	 * @param jwtClaimsSet JWT claims set
	 * @return the first non-empty SSN claim; otherwise the value of the last claim tried
	 *         ({@code null} or an empty string)
	 * @throws ParseException if a claim is present but is not a string
	 */
	private String getSSN(JWTClaimsSet jwtClaimsSet) throws ParseException {
		final String[] candidateClaims = { "ssn", "dk_ssn", "no_ssn", "se_ssn", "fi_ssn", "no_nco_vsuserid" };
		String value = null;
		int index = 0;
		while (index < candidateClaims.length) {
			value = jwtClaimsSet.getStringClaim(candidateClaims[index]);
			index++;
			boolean missing = (value == null) || value.isEmpty();
			if (!missing) {
				break;
			}
		}
		return value;
	}
	
	/**
	 * Reads the "ssn_issuing_country" claim.
	 *
	 * @param jwtClaimsSet JWT claims set
	 * @return the claim value as a string, as returned by the claims set
	 * @throws ParseException if the claim is present but is not a string
	 */
	private String getSsnIssuingCountry(JWTClaimsSet jwtClaimsSet) throws ParseException {
		return jwtClaimsSet.getStringClaim("ssn_issuing_country");
	}

	/**
	 * Looks up the PID under each known claim name, in order, and returns the first non-empty value.
	 *
	 * @param jwtClaimsSet JWT claims set
	 * @return the first non-empty PID claim; otherwise the value of the last claim tried
	 *         ({@code null} or an empty string)
	 * @throws ParseException if a claim is present but is not a string
	 */
	private String getPID(JWTClaimsSet jwtClaimsSet) throws ParseException {
		final String[] claimNames = { "no_bid_pid", "dk_dan_pid", "no_bp_pid","smartid_pid","mobileid_pid" };
		String lastSeen = null;
		for (int i = 0; i < claimNames.length; i++) {
			lastSeen = jwtClaimsSet.getStringClaim(claimNames[i]);
			if (lastSeen == null || lastSeen.isEmpty()) {
				continue; // nothing usable under this name, try the next one
			}
			return lastSeen;
		}
		return lastSeen;
	}

	/**
	 * Creates an HTTP GET request for the given URL and installs an SSL socket factory built from
	 * the given truststore.
	 * <p>
	 * The socket factory is installed through the static
	 * {@code HTTPRequest.setDefaultSSLSocketFactory}, i.e. as the default of {@code HTTPRequest},
	 * not only on the returned request.
	 *
	 * @param url Request url
	 * @param truststorePath SSL truststore path
	 * @param truststorePassword SSL truststore password
	 * @return Generated HTTPRequest object
	 * @throws RuntimeException if {@code url} is malformed or the socket factory cannot be built
	 * @throws Exception if the truststore path or password is null
	 */
	public HTTPRequest getHTTPRequest(String url, String truststorePath, String truststorePassword) throws Exception {
		// Use HTTP GET
		HTTPRequest request;
		try {
			request = new HTTPRequest(HTTPRequest.Method.GET, new URL(url));
		} catch (MalformedURLException ex) {
			logger.error("Exception while getting HTTPRequest", ex);
			// Keep the cause so callers can see why the URL was rejected.
			throw new RuntimeException("Exception while getting HTTPRequest", ex);
		}
		request.setContentType(CommonContentTypes.APPLICATION_URLENCODED);
		SSLSocketFactory sf = getSSLSocketFactory(truststorePath, truststorePassword);
		if (null != sf) {
			HTTPRequest.setDefaultSSLSocketFactory(sf);
		}
		return request;
	}

	/**
	 * Builds an SSLSocketFactory that trusts the certificates of the given JKS truststore.
	 * <p>
	 * The truststore is looked up, in order: at the given path, under {@code target/}, and on the
	 * classpath. The stream is closed once the truststore has been loaded.
	 *
	 * @param truststorePath SSL truststore file path
	 * @param truststorePassword SSL truststore password
	 * @return SSLSocketFactory object
	 * @throws Exception if the truststore path or password is null
	 * @throws RuntimeException if the truststore cannot be found or loaded, or the SSL context
	 *         cannot be initialised; the original failure is attached as the cause
	 */
	public SSLSocketFactory getSSLSocketFactory(String truststorePath, String truststorePassword) throws Exception {
		if (null == truststorePath || null == truststorePassword) {
			throw new Exception("truststore path or password is null");
		}
		try {
			char[] passphrase = truststorePassword.toCharArray();
			KeyStore keystore = KeyStore.getInstance("JKS");

			// try-with-resources: the original never closed the truststore stream.
			try (InputStream truststoreStream = openTruststoreStream(truststorePath)) {
				keystore.load(truststoreStream, passphrase);
			}

			TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
			tmf.init(keystore);

			SSLContext context = SSLContext.getInstance("TLS");
			TrustManager[] trustManagers = tmf.getTrustManagers();

			context.init(null, trustManagers, null);

			return context.getSocketFactory();
		} catch (Exception exp) {
			logger.error("Exception while getting socket factory", exp);
			// Keep the cause so the real failure is visible to callers.
			throw new RuntimeException("Exception while getting SSLSocketFactory", exp);
		}
	}

	/**
	 * Opens the truststore from the file system or, failing that, from the classpath.
	 *
	 * @param truststorePath truststore path or classpath resource name
	 * @return an open stream on the truststore; the caller must close it
	 * @throws IOException if the truststore cannot be found or opened
	 */
	private InputStream openTruststoreStream(String truststorePath) throws IOException {
		File direct = new File(truststorePath);
		if (direct.exists()) {
			logger.debug("Truststore File exists in the current folder where you are!");
			return Files.newInputStream(direct.toPath());
		}

		File underTarget = new File("target/" + truststorePath);
		if (underTarget.exists()) {
			logger.debug("Truststore File exists inside target folder");
			return Files.newInputStream(underTarget.toPath());
		}

		logger.debug("File not found elsewhere, hence fetching from classpath");
		InputStream resourceStream = getClass().getClassLoader().getResourceAsStream(truststorePath);
		if (null == resourceStream) {
			// Loading a null stream would silently create an empty truststore.
			throw new IOException("Truststore not found: " + truststorePath);
		}
		return resourceStream;
	}
	
	/**
	 * Generates the EIdent authorization URL.
	 * <p>
	 * The state parameter is today's date formatted with {@link #DATE_FORMAT}, the same pattern
	 * {@code getClaims} uses to compute the state it expects back.
	 *
	 * @param mid value sent as {@code client_id}
	 * @param scope value sent as {@code scope}; determines the claims to be retrieved
	 * @param itsURL base URL to which the query string is appended
	 * @return the URL with the URL-encoded query parameters appended
	 * @throws RuntimeException if UTF-8 encoding is not supported
	 */
	public String generateEIdentUrl(String mid, String scope, String itsURL) {
		// Use the shared constant so this value cannot drift from the one getClaims checks.
		DateFormat df = new SimpleDateFormat(DATE_FORMAT);
		String state = df.format(new Date());

		StringBuilder eIdentUrlSb = new StringBuilder();
		try {
			eIdentUrlSb.append(itsURL);
			eIdentUrlSb.append("?client_id=");
			eIdentUrlSb.append(URLEncoder.encode(mid, "UTF-8"));
			eIdentUrlSb.append("&scope=");
			eIdentUrlSb.append(URLEncoder.encode(scope, "UTF-8")); // Determines the claims to be retrieved
			eIdentUrlSb.append("&state=");
			eIdentUrlSb.append(URLEncoder.encode(state, "UTF-8"));
			eIdentUrlSb.append("&response_type=code");
			eIdentUrlSb.append("&ntdr-app-doclist=true");
		} catch (UnsupportedEncodingException exp) {
			logger.error("Exception while encoding EIdent URL", exp);
			// Keep the cause so the original failure is not lost.
			throw new RuntimeException("Exception while encoding EIdent URL", exp);
		}
		return eIdentUrlSb.toString();
	}
}
