app/src/main/java/com/isode/stroke/sasl/SCRAMSHA1ClientAuthenticator.java
changeset 1040 197a85a35cba
parent 1018 8daca77fabc1
equal deleted inserted replaced
1039:7d6f2526244a 1040:197a85a35cba
       
     1 /*
       
     2  * Copyright (c) 2010, Isode Limited, London, England.
       
     3  * All rights reserved.
       
     4  */
       
     5 /*
       
     6  * Copyright (c) 2010, Remko Tronçon.
       
     7  * All rights reserved.
       
     8  */
       
     9 package com.isode.stroke.sasl;
       
    10 
       
    11 import com.isode.stroke.base.ByteArray;
       
    12 import com.isode.stroke.stringcodecs.Base64;
       
    13 import com.isode.stroke.stringcodecs.HMACSHA1;
       
    14 import com.isode.stroke.stringcodecs.PBKDF2;
       
    15 import com.isode.stroke.stringcodecs.SHA1;
       
    16 import java.text.Normalizer;
       
    17 import java.text.Normalizer.Form;
       
    18 import java.util.HashMap;
       
    19 import java.util.Map;
       
    20 
       
    21 public class SCRAMSHA1ClientAuthenticator extends ClientAuthenticator {
       
    22 
       
    23     static String escape(String s) {
       
    24         String result = "";
       
    25         for (int i = 0; i < s.length(); ++i) {
       
    26             if (s.charAt(i) == ',') {
       
    27                 result += "=2C";
       
    28             } else if (s.charAt(i) == '=') {
       
    29                 result += "=3D";
       
    30             } else {
       
    31                 result += s.charAt(i);
       
    32             }
       
    33         }
       
    34         return result;
       
    35     }
       
    36 
       
    37     public SCRAMSHA1ClientAuthenticator(String nonce) {
       
    38         this(nonce, false);
       
    39     }
       
    40     public SCRAMSHA1ClientAuthenticator(String nonce, boolean useChannelBinding) {
       
    41         super(useChannelBinding ? "SCRAM-SHA-1-PLUS" : "SCRAM-SHA-1");
       
    42         step = Step.Initial;
       
    43         clientnonce = nonce;
       
    44         this.useChannelBinding = useChannelBinding;
       
    45     }
       
    46 
       
    47     public void setTLSChannelBindingData(ByteArray channelBindingData) {
       
    48         tlsChannelBindingData = channelBindingData;
       
    49     }
       
    50 
       
    51     public ByteArray getResponse() {
       
    52         if (step.equals(Step.Initial)) {
       
    53             return ByteArray.plus(getGS2Header(), getInitialBareClientMessage());
       
    54         } else if (step.equals(Step.Proof)) {
       
    55             ByteArray clientKey = HMACSHA1.getResult(saltedPassword, new ByteArray("Client Key"));
       
    56             ByteArray storedKey = SHA1.getHash(clientKey);
       
    57             ByteArray clientSignature = HMACSHA1.getResult(storedKey, authMessage);
       
    58             ByteArray clientProof = clientKey;
       
    59             byte[] clientProofData = clientProof.getData();
       
    60             for (int i = 0; i < clientProofData.length; ++i) {
       
    61                 clientProofData[i] ^= clientSignature.getData()[i];
       
    62             }
       
    63             ByteArray result = getFinalMessageWithoutProof().append(",p=").append(Base64.encode(clientProof));
       
    64             return result;
       
    65         } else {
       
    66             return null;
       
    67         }
       
    68     }
       
    69 
       
    70     public boolean setChallenge(ByteArray challenge) {
       
    71         if (step.equals(Step.Initial)) {
       
    72             if (challenge == null) {
       
    73                 return false;
       
    74             }
       
    75             initialServerMessage = challenge;
       
    76 
       
    77             Map<Character, String> keys = parseMap(initialServerMessage.toString());
       
    78 
       
    79             // Extract the salt
       
    80             ByteArray salt = Base64.decode(keys.get('s'));
       
    81 
       
    82             // Extract the server nonce
       
    83             String clientServerNonce = keys.get('r');
       
    84             if (clientServerNonce.length() <= clientnonce.length()) {
       
    85                 return false;
       
    86             }
       
    87             String receivedClientNonce = clientServerNonce.substring(0, clientnonce.length());
       
    88             if (!receivedClientNonce.equals(clientnonce)) {
       
    89                 return false;
       
    90             }
       
    91             serverNonce = new ByteArray(clientServerNonce.substring(clientnonce.length()));
       
    92 
       
    93 
       
    94             // Extract the number of iterations
       
    95             int iterations = 0;
       
    96             try {
       
    97                 iterations = Integer.parseInt(keys.get('i'));
       
    98             } catch (NumberFormatException e) {
       
    99                 return false;
       
   100             }
       
   101             if (iterations <= 0) {
       
   102                 return false;
       
   103             }
       
   104 
       
   105             ByteArray channelBindData = new ByteArray();
       
   106             if (useChannelBinding && tlsChannelBindingData != null) {
       
   107                 channelBindData = tlsChannelBindingData;
       
   108             }
       
   109 
       
   110             // Compute all the values needed for the server signature
       
   111             saltedPassword = PBKDF2.encode(new ByteArray(SASLPrep(getPassword())), salt, iterations);
       
   112             authMessage = getInitialBareClientMessage().append(",").append(initialServerMessage).append(",").append(getFinalMessageWithoutProof());
       
   113             ByteArray serverKey = HMACSHA1.getResult(saltedPassword, new ByteArray("Server Key"));
       
   114             serverSignature = HMACSHA1.getResult(serverKey, authMessage);
       
   115 
       
   116             step = Step.Proof;
       
   117             return true;
       
   118         } else if (step.equals(step.Proof)) {
       
   119             ByteArray result = new ByteArray("v=").append(new ByteArray(Base64.encode(serverSignature)));
       
   120             step = Step.Final;
       
   121             return challenge != null && challenge.equals(result);
       
   122         } else {
       
   123             return true;
       
   124         }
       
   125     }
       
   126 
       
   127     private String SASLPrep(String source) {
       
   128         return Normalizer.normalize(source, Form.NFKC); /* FIXME: Implement real SASLPrep */
       
   129     }
       
   130 
       
   131     private Map<Character, String> parseMap(String s) {
       
   132         HashMap<Character, String> result = new HashMap<Character, String>();
       
   133         if (s.length() > 0) {
       
   134             char key = '~'; /* initialise so it'll compile */
       
   135             String value = "";
       
   136             int i = 0;
       
   137             boolean expectKey = true;
       
   138             while (i < s.length()) {
       
   139                 if (expectKey) {
       
   140                     key = s.charAt(i);
       
   141                     expectKey = false;
       
   142                     i++;
       
   143                 } else if (s.charAt(i) == ',') {
       
   144                     result.put(key, value);
       
   145                     value = "";
       
   146                     expectKey = true;
       
   147                 } else {
       
   148                     value += s.charAt(i);
       
   149                 }
       
   150                 i++;
       
   151             }
       
   152             result.put(key, value);
       
   153         }
       
   154         return result;
       
   155     }
       
   156 
       
   157     private ByteArray getInitialBareClientMessage() {
       
   158         String authenticationID = SASLPrep(getAuthenticationID());
       
   159         return new ByteArray("n=" + escape(authenticationID) + ",r=" + clientnonce);
       
   160     }
       
   161 
       
   162     private ByteArray getGS2Header() {
       
   163 
       
   164         ByteArray channelBindingHeader = new ByteArray("n");
       
   165 	if (tlsChannelBindingData != null) {
       
   166 		if (useChannelBinding) {
       
   167 			channelBindingHeader = new ByteArray("p=tls-unique");
       
   168 		}
       
   169 		else {
       
   170 			channelBindingHeader = new ByteArray("y");
       
   171 		}
       
   172 	}
       
   173 	return new ByteArray().append(channelBindingHeader).append(",").append(getAuthorizationID().isEmpty() ? new ByteArray() : new ByteArray("a=" + escape(getAuthorizationID()))).append(",");
       
   174     }
       
   175 
       
   176     private ByteArray getFinalMessageWithoutProof() {
       
   177         ByteArray channelBindData = new ByteArray();
       
   178 	if (useChannelBinding && tlsChannelBindingData != null) {
       
   179 		channelBindData = tlsChannelBindingData;
       
   180 	}
       
   181 	return new ByteArray("c=" + Base64.encode(new ByteArray(getGS2Header()).append(channelBindData)) + ",r=" + clientnonce).append(serverNonce);
       
   182     }
       
   183 
       
   184     private enum Step {
       
   185 
       
   186         Initial,
       
   187         Proof,
       
   188         Final
       
   189     };
       
   190     private Step step;
       
   191     private String clientnonce = "";
       
   192     private ByteArray initialServerMessage = new ByteArray();
       
   193     private ByteArray serverNonce = new ByteArray();
       
   194     private ByteArray authMessage = new ByteArray();
       
   195     private ByteArray saltedPassword = new ByteArray();
       
   196     private ByteArray serverSignature = new ByteArray();
       
   197     private boolean useChannelBinding;
       
   198     private ByteArray tlsChannelBindingData;
       
   199 }