Type safe Scala - Phantom types

Jun 16, 2017
functional-programming
scala

Boo! 👻

Phantom type parameters are one of my favourite features of type systems with parametric polymorphism, such as those of Java, Kotlin, Scala, C#, F#, Haskell, OCaml, and many others.

What is a phantom type parameter? A phantom type parameter is a type parameter that is checked statically by the compiler but not used at runtime. Let’s see an example:

final case class Key[KeyType](jkey: java.security.Key)

KeyType looks like a useless type parameter: no field of Key mentions it, and we never need a value of type KeyType to create a value of type Key.
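To make the "not used at runtime" part concrete, here is a small sketch that reuses Key together with the PK and SK tags from the example further down. The two keys have different static types, yet at runtime both are plain Key instances, because the JVM erases type arguments.

```scala
sealed trait PK // public key tag, as in the article
sealed trait SK // private key tag, as in the article

// KeyType never appears in a field: it exists only for the compiler
final case class Key[KeyType](jkey: java.security.Key)

val pair = java.security.KeyPairGenerator.getInstance("RSA").generateKeyPair()
val pk: Key[PK] = Key(pair.getPublic)
val sk: Key[SK] = Key(pair.getPrivate)

// At runtime the type argument is erased: both values are plain Key instances
val sameRuntimeClass = pk.getClass == sk.getClass // true

// Statically, however, the two types are unrelated:
// val wrong: Key[SK] = pk // does not compile: Key[PK] is not a Key[SK]
```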

What are they good for? They let you tag types with additional information, which in turn lets you define compiler-checked invariants: the order in which some operations must be performed, how operations compose, or how the additional information propagates through them.
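As an illustration of the "order of operations" invariant, here is a minimal sketch built around a hypothetical SafeCipher wrapper with hypothetical Uninitialized and Initialized tags (none of these names come from a library). The phantom State parameter records whether init has run, and =:= evidence makes doFinal uncallable before that.

```scala
import javax.crypto.Cipher

// Hypothetical state tags, for illustration only
sealed trait Uninitialized
sealed trait Initialized

// State is a phantom type parameter: it only records, at compile time,
// whether init has been called. SafeCipher is a hypothetical wrapper.
final class SafeCipher[State] private (underlying: Cipher) {
  // Allowed only while State is Uninitialized; returns an Initialized view
  def init(mode: Int, key: java.security.Key)(implicit ev: State =:= Uninitialized): SafeCipher[Initialized] = {
    underlying.init(mode, key)
    new SafeCipher[Initialized](underlying)
  }

  // Allowed only once State is Initialized
  def doFinal(data: Array[Byte])(implicit ev: State =:= Initialized): Array[Byte] =
    underlying.doFinal(data)
}

object SafeCipher {
  // Every new wrapper starts in the Uninitialized state
  def apply(transformation: String): SafeCipher[Uninitialized] =
    new SafeCipher[Uninitialized](Cipher.getInstance(transformation))
}

val pair    = java.security.KeyPairGenerator.getInstance("RSA").generateKeyPair()
val message = "phantom".getBytes("UTF-8")

val encrypted = SafeCipher("RSA").init(Cipher.ENCRYPT_MODE, pair.getPublic).doFinal(message)
val decrypted = SafeCipher("RSA").init(Cipher.DECRYPT_MODE, pair.getPrivate).doFinal(encrypted)

// SafeCipher("RSA").doFinal(message)
// does not compile: no evidence that Uninitialized =:= Initialized
```

Because both views share one mutable Cipher, this sketch only illustrates the typing discipline; a leftover SafeCipher[Uninitialized] reference could still be initialized a second time.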

Here's an example of how to use phantom type parameters to check statically that a piece of information encrypted with a public key is decrypted with a private key:

sealed trait PK  // public key type
sealed trait SK  // private key type
final case class Key   [KeyType](jkey: java.security.Key)
// carries key type info ^^^^^^
final case class Secret[KeyType](bytes: Array[Byte])
//                      ^^^^^^^
// carries key type used for encryption
def generateKeyPair: (Key[PK], Key[SK]) = {
  val gen = java.security.KeyPairGenerator.getInstance("RSA")
  val pair = gen.generateKeyPair
  (Key(pair.getPublic), Key(pair.getPrivate))
}
def encrypt(key: Key[PK], data: Array[Byte]): Secret[PK] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))
}
def decrypt(key: Key[SK], secret: Secret[PK]): Array[Byte] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)
}
val data     = "phantom".getBytes
val (pk, sk) = generateKeyPair
val secret = encrypt(pk, data)
decrypt(sk, encrypt(pk, data)).toVector == data.toVector // true
// if we make a mistake and try to decrypt with the encryption key
// we'll get a compiler error
//         vv--^^^^^^^^^^^^^^
// decrypt(pk, encrypt(pk, data)).toVector == data.toVector
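The tag on Secret can also be carried through transformations that never touch a key. As a sketch, assuming a hypothetical Encoded wrapper for shipping a secret as Base64 text (not part of the original code), generic functions preserve the phantom type, so a decoded secret still remembers which key type encrypted it.

```scala
import java.util.Base64

sealed trait PK // public key tag
sealed trait SK // private key tag
final case class Secret[KeyType](bytes: Array[Byte])

// Hypothetical wire format: Base64 text tagged with the same phantom type
final case class Encoded[KeyType](text: String)

// The type parameter A flows from input to output untouched
def encode[A](secret: Secret[A]): Encoded[A] =
  Encoded(Base64.getEncoder.encodeToString(secret.bytes))

def decode[A](encoded: Encoded[A]): Secret[A] =
  Secret(Base64.getDecoder.decode(encoded.text))

val original = Secret[PK]("phantom".getBytes("UTF-8"))
val roundTrip: Secret[PK] = decode(encode(original)) // the PK tag survives

// val wrong: Secret[SK] = decode(encode(original)) // does not compile
```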

We can improve the encryption example further to cover the case where we encrypt with the private key and decrypt with the public key. I leave you the entire snippet as a gist because it is a bit longer; notice, however, how an implicit parameter of type CanDecrypt establishes the encryption/decryption relation we need between the different key types: decrypt only compiles when an implicit CanDecrypt[A, B] exists for the decryption key type A and the encryption key type B.

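// Evidence that DecryptKeyType keys can decrypt secrets encrypted with EncryptKeyType keys
@scala.annotation.implicitNotFound("${EncryptKeyType} encrypted messages cannot be decrypted with ${DecryptKeyType}")
trait CanDecrypt[DecryptKeyType, EncryptKeyType]
object CanDecrypt {
  // A public key decrypts what a private key encrypted, and vice versa;
  // explicit types on implicit definitions are the recommended style
  implicit val pkOnSk: CanDecrypt[PK, SK] = new CanDecrypt[PK, SK] {}
  implicit val skOnPk: CanDecrypt[SK, PK] = new CanDecrypt[SK, PK] {}
}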

sealed trait PK  // public key type
sealed trait SK  // private key type
final case class Key   [KeyType](jkey: java.security.Key)
// carries key type info ^^^^^^
final case class Secret[KeyType](bytes: Array[Byte])
//                      ^^^^^^^
// carries key type used for encryption

def generateKeyPair: (Key[PK], Key[SK]) = {
  val gen = java.security.KeyPairGenerator.getInstance("RSA")
  val pair = gen.generateKeyPair
  (Key(pair.getPublic), Key(pair.getPrivate))
}

def encrypt[A](key: Key[A], data: Array[Byte]): Secret[A] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))
}

def decrypt[A, B](key: Key[A], secret: Secret[B])(implicit ev: CanDecrypt[A, B]): Array[Byte] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)
}

val data     = "phantom".getBytes
val (pk, sk) = generateKeyPair
val secret = encrypt(pk, data)

decrypt(sk, encrypt(pk, data)).toVector == data.toVector // true
decrypt(pk, encrypt(sk, data)).toVector == data.toVector // true

// If we try to decrypt with the same key type used for encryption:
// decrypt(sk, encrypt(sk, data)).toVector == data.toVector
// we get a compile error with the message:
// SK encrypted messages cannot be decrypted with SK
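The gist uses Scala 2 implicits. As a hedged sketch for readers on Scala 3 (this version assumes a Scala 3 compiler and will not compile on Scala 2), the same relation can be written with given instances and a using clause; the phantom types and the CanDecrypt relation are unchanged.

```scala
import scala.annotation.implicitNotFound
import javax.crypto.Cipher

sealed trait PK // public key tag
sealed trait SK // private key tag

final case class Key[KeyType](jkey: java.security.Key)
final case class Secret[KeyType](bytes: Array[Byte])

// Same relation as the gist: the first parameter is the decryption key type
@implicitNotFound("${EncryptKeyType} encrypted messages cannot be decrypted with ${DecryptKeyType}")
trait CanDecrypt[DecryptKeyType, EncryptKeyType]

object CanDecrypt:
  // Found through the implicit scope of CanDecrypt, no import needed
  given pkOnSk: CanDecrypt[PK, SK] = new CanDecrypt[PK, SK] {}
  given skOnPk: CanDecrypt[SK, PK] = new CanDecrypt[SK, PK] {}

def generateKeyPair(): (Key[PK], Key[SK]) =
  val pair = java.security.KeyPairGenerator.getInstance("RSA").generateKeyPair()
  (Key(pair.getPublic), Key(pair.getPrivate))

def encrypt[A](key: Key[A], data: Array[Byte]): Secret[A] =
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))

// `using` replaces the implicit parameter list; the evidence is never used at runtime
def decrypt[A, B](key: Key[A], secret: Secret[B])(using CanDecrypt[A, B]): Array[Byte] =
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)

val data     = "phantom".getBytes("UTF-8")
val (pk, sk) = generateKeyPair()

val viaPublic  = decrypt(sk, encrypt(pk, data)) // needs CanDecrypt[SK, PK]
val viaPrivate = decrypt(pk, encrypt(sk, data)) // needs CanDecrypt[PK, SK]

// decrypt(sk, encrypt(sk, data))
// compile error: SK encrypted messages cannot be decrypted with SK
```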