
package com.beem.project.beem.smack.caps;

import org.jivesoftware.smack.Connection;
import org.jivesoftware.smack.XMPPException;
import org.jivesoftware.smackx.packet.DiscoverInfo;
import org.jivesoftware.smack.packet.Packet;
import org.jivesoftware.smack.packet.PacketExtension;
import org.jivesoftware.smackx.ServiceDiscoveryManager;
import org.jivesoftware.smack.util.collections.ReferenceMap;
import org.jivesoftware.smack.PacketListener;
import org.jivesoftware.smack.filter.PacketFilter;
import org.jivesoftware.smack.filter.PacketExtensionFilter;

import java.util.Map;
import java.util.Iterator;
import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import java.security.NoSuchAlgorithmException;
import java.security.MessageDigest;

import org.jivesoftware.smack.util.StringUtils;

/**
 * Caches service discovery information keyed by entity capabilities
 * verification strings (the "ver" attribute of the
 * {@code <c xmlns='http://jabber.org/protocol/caps'/>} presence extension).
 *
 * On construction a packet listener is registered on the connection. For each
 * packet carrying a caps extension whose ver is not yet cached, the sender is
 * queried through the {@link ServiceDiscoveryManager} and the answer is checked
 * against the advertised ver before being stored.
 */
public class CapsManager {
    /** Element name of the entity capabilities extension. */
    private static final String CAPS_ELEMENT = "c";
    /** Namespace of the entity capabilities extension. */
    private static final String CAPS_NAMESPACE = "http://jabber.org/protocol/caps";

    // the verCache should be stored on disk
    /** Discovery results whose recomputed ver matched the advertised one, keyed by ver. */
    Map<String, DiscoverInfo> mVerCache = new ReferenceMap<String, DiscoverInfo>();
    /** Discovery results that could not be verified (unsupported hash), keyed by JID. */
    Map<String, DiscoverInfo> mJidCache = new ReferenceMap<String, DiscoverInfo>();

    private ServiceDiscoveryManager mSdm;
    private Connection mConnection;
    /** Hash names, among the candidates tried at startup, that MessageDigest can provide. */
    private List<String> mSupportedAlgorithm = new ArrayList<String>();

    /**
     * Create a CapsManager and start listening for caps extensions.
     *
     * @param sdm the service discovery manager used to query entities
     * @param conn the connection on which the packet listener is registered
     */
    public CapsManager(ServiceDiscoveryManager sdm, Connection conn) {
        mSdm = sdm;
        mConnection = conn;
        init();
    }

    /**
     * Probe the available hash algorithms and register the listener that
     * validates every caps extension whose ver is not already cached.
     */
    private void init() {
        initSupportedAlgorithm();
        PacketFilter filter = new PacketExtensionFilter(CAPS_ELEMENT, CAPS_NAMESPACE);
        mConnection.addPacketListener(new PacketListener() {
            public void processPacket(Packet packet) {
                PacketExtension p = packet.getExtension(CAPS_ELEMENT, CAPS_NAMESPACE);
                // The extension may be absent or parsed as another type;
                // ignore it rather than failing inside the listener.
                if (!(p instanceof CapsExtension)) {
                    return;
                }
                CapsExtension caps = (CapsExtension) p;
                if (!mVerCache.containsKey(caps.getVer())) {
                    validate(packet.getFrom(), caps.getVer(), caps.getHash());
                }
            }
        }, filter);
    }

    /**
     * Get the discovery information cached for a verification string.
     *
     * @param ver the verification string
     * @return the cached information or null if there is none
     */
    public DiscoverInfo getDiscoverInfo(String ver) {
        return mVerCache.get(ver);
    }

    /**
     * Get the discovery information cached for a verification string, falling
     * back on the information cached for the JID.
     *
     * @param jid the JID of the entity
     * @param ver the verification string
     * @return the cached information or null if neither cache has an entry
     */
    public DiscoverInfo getDiscoverInfo(String jid, String ver) {
        DiscoverInfo info = mVerCache.get(ver);
        if (info == null) {
            info = mJidCache.get(jid);
        }
        return info;
    }

    /**
     * Query an entity for its discovery information and check it against the
     * verification string it advertised.
     *
     * If the hash method is not supported the information cannot be verified:
     * it is stored in the JID cache and false is returned. Otherwise the ver is
     * recomputed and, when it matches, the information is stored in the ver
     * cache.
     *
     * @param jid the JID of the entity to query
     * @param ver the verification string advertised by the entity
     * @param hashMethod the name of the hash algorithm advertised by the entity
     * @return true if the recomputed ver equals the advertised one, false
     *         otherwise (including when the query or the hash computation fails)
     */
    private boolean validate(String jid, String ver, String hashMethod) {
        try {
            DiscoverInfo info = mSdm.discoverInfo(jid);
            if (!mSupportedAlgorithm.contains(hashMethod)) {
                mJidCache.put(jid, info);
                return false;
            }
            String v = calculateVer(info, hashMethod);
            boolean res = v.equals(ver);
            if (res) {
                mVerCache.put(ver, info);
            }
            return res;
        } catch (XMPPException e) {
            e.printStackTrace();
            return false;
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Compute the verification string of a discovery result.
     *
     * The hashed string is the concatenation of each sorted identity as
     * {@code category/type/lang/name<} (lang is left empty) followed by each
     * sorted feature as {@code feature<}. The digest is returned Base64 encoded.
     *
     * @param info the discovery result
     * @param hashMethod the name of the hash algorithm to use
     * @return the Base64 encoded hash
     * @throws NoSuchAlgorithmException if the hash algorithm is not available
     */
    private String calculateVer(DiscoverInfo info, String hashMethod) throws NoSuchAlgorithmException {
        StringBuilder sb = new StringBuilder();
        for (DiscoverInfo.Identity identity : getSortedIdentity(info)) {
            String c = identity.getCategory();
            if (c != null) {
                sb.append(c);
            }
            sb.append('/');
            c = identity.getType();
            if (c != null) {
                sb.append(c);
            }
            sb.append('/');
            // The lang field belongs here but is not available from the
            // identity, so it is left empty.
            sb.append('/');
            c = identity.getName();
            if (c != null) {
                sb.append(c);
            }
            sb.append('<');
        }
        for (String f : getSortedFeature(info)) {
            sb.append(f);
            sb.append('<');
        }
        // Should add data form (XEP 0128) but it is not available
        // Encode explicitly as UTF-8 so that the hash does not depend on the
        // platform default charset.
        byte[] data = sb.toString().getBytes(java.nio.charset.Charset.forName("UTF-8"));
        byte[] hash = getHash(hashMethod, data);
        return StringUtils.encodeBase64(hash);
    }

    /**
     * Get the identities of a discovery result sorted by category then by
     * type. A null category or type is treated as an empty string.
     *
     * @param info the discovery result
     * @return a new sorted list of the identities
     */
    private List<DiscoverInfo.Identity> getSortedIdentity(DiscoverInfo info) {
        List<DiscoverInfo.Identity> result = new ArrayList<DiscoverInfo.Identity>();
        Iterator<DiscoverInfo.Identity> it = info.getIdentities();
        while (it.hasNext()) {
            result.add(it.next());
        }
        Collections.sort(result, new Comparator<DiscoverInfo.Identity>() {
            public int compare(DiscoverInfo.Identity o1, DiscoverInfo.Identity o2) {
                int res = nullToEmpty(o1.getCategory()).compareTo(nullToEmpty(o2.getCategory()));
                if (res != 0) {
                    return res;
                }
                // Compare type against type: the second operand must be the
                // type of o2, not its category.
                res = nullToEmpty(o1.getType()).compareTo(nullToEmpty(o2.getType()));
                if (res != 0) {
                    return res;
                }
                // should compare lang but it is not available
                return 0;
            }

            private String nullToEmpty(String s) {
                return s == null ? "" : s;
            }
        });
        return result;
    }

    /**
     * Get the feature names of a discovery result in their natural order.
     *
     * @param info the discovery result
     * @return a new sorted list of the feature names
     */
    private List<String> getSortedFeature(DiscoverInfo info) {
        List<String> result = new ArrayList<String>();
        Iterator<DiscoverInfo.Feature> it = info.getFeatures();
        while (it.hasNext()) {
            DiscoverInfo.Feature feat = it.next();
            result.add(feat.getVar());
        }
        Collections.sort(result);
        return result;
    }

    /**
     * Hash some data.
     *
     * @param algo the name of the hash algorithm
     * @param data the data to hash
     * @return the digest of the data
     * @throws NoSuchAlgorithmException if the hash algorithm is not available
     */
    private byte[] getHash(String algo, byte[] data) throws NoSuchAlgorithmException {
        MessageDigest md = MessageDigest.getInstance(algo);
        return md.digest(data);
    }

    /**
     * Fill the list of supported algorithms with every candidate hash name for
     * which a MessageDigest implementation can be obtained.
     */
    private void initSupportedAlgorithm() {
        String[] candidates = new String[] {"md2", "md5", "sha-1", "sha-224", "sha-256", "sha-384", "sha-512" };
        for (String a : candidates) {
            try {
                // Only the availability matters, the instance itself is discarded.
                MessageDigest.getInstance(a);
                mSupportedAlgorithm.add(a);
            } catch (NoSuchAlgorithmException e) {
                System.err.println("Hash algorithm " + a + " not supported");
            }
        }
    }

}
