Type safe Scala: Phantom types
Boo! 👻
Phantom type parameters are one of my favourite features of type systems with parametric polymorphism, such as those of Java, Kotlin, Scala, C#, F#, Haskell, OCaml, and many others.
What is a phantom type parameter? A phantom type parameter is a type parameter that is checked statically by the compiler but not used at runtime. Let's see an example:
final case class Key[KeyType](jkey: java.security.Key)
KeyType looks like a useless type parameter: no field has that type, so we never need a value of type KeyType to create a value of type Key.
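A quick way to convince yourself that KeyType is a compile-time-only tag is the following sketch. It assumes two marker traits, PK and SK (the same names used in the next example), that nothing ever extends. We can still build a Key[PK], and at runtime a Key[PK] and a Key[SK] share the same class because the type argument is erased.

```scala
import java.security.KeyPairGenerator

// Marker traits with no implementations: they exist only as type arguments
sealed trait PK
sealed trait SK

final case class Key[KeyType](jkey: java.security.Key)

val pair = KeyPairGenerator.getInstance("RSA").generateKeyPair()

// We pick the type argument freely; no value of type PK or SK is involved
val publicKey: Key[PK] = Key[PK](pair.getPublic)
val privateKey: Key[SK] = Key[SK](pair.getPrivate)

// The type argument is erased, so both values have the same runtime class
val sameRuntimeClass: Boolean = publicKey.getClass == privateKey.getClass
```

Note that writing Key[SK](pair.getPublic) would compile just as well: the tag is whatever we declare. The safety comes from creating keys in one trusted place (like generateKeyPair below) and from the signatures of the functions that consume them.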
What are they good for? They let you tag or mark types with additional information, which ultimately lets you define compiler-checked invariants on the order in which some operations must happen, how they compose, or how they propagate the additional information.
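The point about the order of operations is easy to see away from cryptography. Here is a minimal sketch with a hypothetical Door type (Door, Opened, Closed, open and close are all made-up names for illustration): the phantom State parameter records whether the door is open, and each function states which state it accepts and which one it returns.

```scala
// Hypothetical marker traits, used only as phantom type arguments
sealed trait Opened
sealed trait Closed

// State is phantom: no field of type State exists
final case class Door[State](name: String)

// Each operation declares the state it requires and the state it produces
def open(door: Door[Closed]): Door[Opened] = Door[Opened](door.name)
def close(door: Door[Opened]): Door[Closed] = Door[Closed](door.name)

val frontDoor: Door[Closed] = Door[Closed]("front")

// Alternating open and close type-checks...
val reopened: Door[Opened] = open(close(open(frontDoor)))

// ...but opening twice does not, because a Door[Opened] is not a Door[Closed]:
// open(open(frontDoor))
```

No value of type Opened or Closed is ever created; the compiler alone tracks the state, and at runtime a Door is just a wrapper around a String.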
Here's an example of how to use phantom type parameters to check statically that a piece of information encrypted with a public key is decrypted with a private key:
sealed trait PK // public key type
sealed trait SK // private key type

// KeyType is a phantom type parameter: it carries the key type info
final case class Key[KeyType](jkey: java.security.Key)

// KeyType carries the type of the key used for encryption
final case class Secret[KeyType](bytes: Array[Byte])

def generateKeyPair: (Key[PK], Key[SK]) = {
  val gen = java.security.KeyPairGenerator.getInstance("RSA")
  val pair = gen.generateKeyPair
  (Key(pair.getPublic), Key(pair.getPrivate))
}

def encrypt(key: Key[PK], data: Array[Byte]): Secret[PK] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))
}

def decrypt(key: Key[SK], secret: Secret[PK]): Array[Byte] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)
}

val data = "phantom".getBytes
val (pk, sk) = generateKeyPair
val secret = encrypt(pk, data)

decrypt(sk, encrypt(pk, data)).toVector == data.toVector // true

// If we make a mistake and try to decrypt with the encryption key,
// we get a compile error, because pk is a Key[PK] where a Key[SK] is expected:
// decrypt(pk, encrypt(pk, data)).toVector == data.toVector
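For contrast, here is roughly what the same mistake looks like when keys are not tagged. This sketch calls the javax.crypto API directly through two hypothetical helpers, rawEncrypt and rawDecrypt, that accept any java.security.Key: decrypting with the public key compiles without complaint and only fails when it runs. It assumes the default JDK provider, which reports this misuse as an exception (typically a javax.crypto.BadPaddingException), so the sketch only checks that the call fails.

```scala
import java.security.KeyPairGenerator
import javax.crypto.Cipher
import scala.util.Try

val pair = KeyPairGenerator.getInstance("RSA").generateKeyPair()
val data = "phantom".getBytes

// Hypothetical untyped helpers: any java.security.Key is accepted for either role
def rawEncrypt(key: java.security.Key, bytes: Array[Byte]): Array[Byte] = {
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.ENCRYPT_MODE, key)
  cipher.doFinal(bytes)
}

def rawDecrypt(key: java.security.Key, bytes: Array[Byte]): Array[Byte] = {
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.DECRYPT_MODE, key)
  cipher.doFinal(bytes)
}

val encrypted = rawEncrypt(pair.getPublic, data)

// Right key: the round trip succeeds
val ok = Try(rawDecrypt(pair.getPrivate, encrypted).toVector == data.toVector)

// Wrong key: this compiles, and the error only surfaces at runtime
val wrong = Try(rawDecrypt(pair.getPublic, encrypted))
```

The typed encrypt and decrypt functions above move this failure from runtime to compile time.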
We can further improve that example to cover the case where we encrypt with the private key and decrypt with the public key. I leave you the entire snippet of code as a gist because it is a bit longer; notice, however, how an implicit parameter CanDecrypt is used to establish the encryption/decryption relation we need between the different keys.
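@scala.annotation.implicitNotFound("${EncryptKeyType} encrypted messages cannot be decrypted with ${DecryptKeyType}")
trait CanDecrypt[DecryptKeyType, EncryptKeyType]

object CanDecrypt {
  // a public key can decrypt what a private key encrypted
  implicit val pkOnSk: CanDecrypt[PK, SK] = new CanDecrypt[PK, SK] {}
  // a private key can decrypt what a public key encrypted
  implicit val skOnPk: CanDecrypt[SK, PK] = new CanDecrypt[SK, PK] {}
}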
sealed trait PK // public key type
sealed trait SK // private key type

// KeyType is a phantom type parameter: it carries the key type info
final case class Key[KeyType](jkey: java.security.Key)

// KeyType carries the type of the key used for encryption
final case class Secret[KeyType](bytes: Array[Byte])

def generateKeyPair: (Key[PK], Key[SK]) = {
  val gen = java.security.KeyPairGenerator.getInstance("RSA")
  val pair = gen.generateKeyPair
  (Key(pair.getPublic), Key(pair.getPrivate))
}

def encrypt[A](key: Key[A], data: Array[Byte]): Secret[A] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))
}

def decrypt[A, B](key: Key[A], secret: Secret[B])(implicit ev: CanDecrypt[A, B]): Array[Byte] = {
  val cipher = javax.crypto.Cipher.getInstance("RSA")
  cipher.init(javax.crypto.Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)
}

val data = "phantom".getBytes
val (pk, sk) = generateKeyPair
val secret = encrypt(pk, data)

decrypt(sk, encrypt(pk, data)).toVector == data.toVector // true
decrypt(pk, encrypt(sk, data)).toVector == data.toVector // true
// If we try to decrypt with the same key used for encryption:
// decrypt(sk, encrypt(sk, data)).toVector == data.toVector
// we will get a compile error with the message:
// SK encrypted messages cannot be decrypted with SK
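The CanDecrypt requirement also composes: a function that is generic in both key types cannot pick an instance by itself, so it has to ask its own caller for one, and the check happens wherever the concrete key types become known. Below is a sketch with a hypothetical roundTrip helper; the remaining definitions are copied from the snippet above (with explicit result types on the implicit vals) so the block compiles on its own.

```scala
import java.security.KeyPairGenerator
import javax.crypto.Cipher

sealed trait PK // public key type
sealed trait SK // private key type

final case class Key[KeyType](jkey: java.security.Key)
final case class Secret[KeyType](bytes: Array[Byte])

@scala.annotation.implicitNotFound("${EncryptKeyType} encrypted messages cannot be decrypted with ${DecryptKeyType}")
trait CanDecrypt[DecryptKeyType, EncryptKeyType]

object CanDecrypt {
  implicit val pkOnSk: CanDecrypt[PK, SK] = new CanDecrypt[PK, SK] {}
  implicit val skOnPk: CanDecrypt[SK, PK] = new CanDecrypt[SK, PK] {}
}

def generateKeyPair: (Key[PK], Key[SK]) = {
  val pair = KeyPairGenerator.getInstance("RSA").generateKeyPair()
  (Key[PK](pair.getPublic), Key[SK](pair.getPrivate))
}

def encrypt[A](key: Key[A], data: Array[Byte]): Secret[A] = {
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.ENCRYPT_MODE, key.jkey)
  Secret(cipher.doFinal(data))
}

def decrypt[A, B](key: Key[A], secret: Secret[B])(implicit ev: CanDecrypt[A, B]): Array[Byte] = {
  val cipher = Cipher.getInstance("RSA")
  cipher.init(Cipher.DECRYPT_MODE, key.jkey)
  cipher.doFinal(secret.bytes)
}

// Hypothetical helper: generic in both key types, so it cannot choose a
// CanDecrypt instance itself and forwards the requirement to its caller
def roundTrip[E, D](encKey: Key[E], decKey: Key[D], data: Array[Byte])(
    implicit ev: CanDecrypt[D, E]): Boolean =
  decrypt(decKey, encrypt(encKey, data)).toVector == data.toVector

val (pk, sk) = generateKeyPair
val data = "phantom".getBytes

val viaPublic = roundTrip(pk, sk, data)  // resolves CanDecrypt[SK, PK]
val viaPrivate = roundTrip(sk, pk, data) // resolves CanDecrypt[PK, SK]
// roundTrip(sk, sk, data) does not compile: there is no CanDecrypt[SK, SK]
```

This is the "propagation" aspect mentioned earlier: the constraint travels through every generic layer as an implicit parameter, at no runtime cost beyond passing one empty object.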