diff --git a/lib/jwa/aes_cbc_hmac_sha2.js b/lib/jwa/aes_cbc_hmac_sha2.js index 71a34a2d..e5d1d6e5 100644 --- a/lib/jwa/aes_cbc_hmac_sha2.js +++ b/lib/jwa/aes_cbc_hmac_sha2.js @@ -4,8 +4,8 @@ const { strict: assert } = require('assert') const { TODO } = require('../errors') +// TODO: DRY const MAX_INT32 = Math.pow(2, 32) - const uint64be = (value, buf = Buffer.alloc(8)) => { const high = Math.floor(value / MAX_INT32) const low = value % MAX_INT32 diff --git a/lib/jwa/aes_kw.js b/lib/jwa/aes_kw.js new file mode 100644 index 00000000..6cf23bb0 --- /dev/null +++ b/lib/jwa/aes_kw.js @@ -0,0 +1,104 @@ +const { createCipheriv, createDecipheriv } = require('crypto') +const { strict: assert } = require('assert') + +const { TODO } = require('../errors') + +// TODO: DRY +const MAX_INT32 = Math.pow(2, 32) +const uint64be = (value, buf = Buffer.alloc(8)) => { + const high = Math.floor(value / MAX_INT32) + const low = value % MAX_INT32 + + buf.writeUInt32BE(high, 0) + buf.writeUInt32BE(low, 4) + return buf +} + +const A0 = Buffer.alloc(8, 'a6', 'hex') + +function xor (a, b) { + const len = Math.max(a.length, b.length) + const result = Buffer.alloc(len) + for (let idx = 0; len > idx; idx++) { + result[idx] = (a[idx] || 0) ^ (b[idx] || 0) + } + + return result +} + +function split (input, size) { + const output = [] + for (let idx = 0; input.length > idx; idx += size) { + output.push(input.slice(idx, idx + size)) + } + return output +} + +const wrapKey = (size, { keyObject }, payload) => { + // TODO: commonCheck + + const iv = Buffer.alloc(16) + let R = split(payload, 8) + let A + let B + let count + A = A0 + for (let jdx = 0; jdx < 6; jdx++) { + for (let idx = 0; R.length > idx; idx++) { + count = (R.length * jdx) + idx + 1 + const cipher = createCipheriv(`AES${size}`, keyObject, iv) + B = Buffer.concat([A, R[idx]]) + B = cipher.update(B) + + A = xor(B.slice(0, 8), uint64be(count)) + R[idx] = B.slice(8, 16) + } + } + R = [A].concat(R) + + return Buffer.concat(R) +} + +const unwrapKey = (size, { keyObject }, payload) => { + // TODO: commonCheck + + const iv = Buffer.alloc(16) + + let R = split(payload, 8) + let A + let B + let count + A = R[0] + R = R.slice(1) + for (let jdx = 5; jdx >= 0; --jdx) { + for (let idx = R.length - 1; idx >= 0; --idx) { + count = (R.length * jdx) + idx + 1 + B = xor(A, uint64be(count)) + B = Buffer.concat([B, R[idx], iv]) + const cipher = createDecipheriv(`AES${size}`, keyObject, iv) + B = cipher.update(B) + + A = B.slice(0, 8) + R[idx] = B.slice(8, 16) + } + } + + // TODO timingSafeEqual + if (A.toString() !== A0.toString()) { + throw new TODO('decryption failed') + } + + return Buffer.concat(R) +} + +module.exports = (JWA) => { + ['A128KW', 'A192KW', 'A256KW'].forEach((jwaAlg) => { + const size = parseInt(jwaAlg.substr(1, 3), 10) + + assert(!JWA.wrapKey.has(jwaAlg), `wrapKey alg ${jwaAlg} already registered`) + assert(!JWA.unwrapKey.has(jwaAlg), `unwrapKey alg ${jwaAlg} already registered`) + + JWA.wrapKey.set(jwaAlg, wrapKey.bind(undefined, size)) + JWA.unwrapKey.set(jwaAlg, unwrapKey.bind(undefined, size)) + }) +} diff --git a/lib/jwa/index.js b/lib/jwa/index.js index 174fe874..f0356060 100644 --- a/lib/jwa/index.js +++ b/lib/jwa/index.js @@ -15,13 +15,14 @@ require('./ecdsa')(JWA) require('./rsassa_pss')(JWA) require('./rsassa_pkcs1')(JWA) -// wrapKey, unwrapKey -require('./rsaes')(JWA) - // encrypt, decrypt require('./aes_cbc_hmac_sha2')(JWA) require('./aes_gcm')(JWA) + +// wrapKey, unwrapKey +require('./rsaes')(JWA) require('./aes_gcm_kw')(JWA) +require('./aes_kw')(JWA) module.exports = { sign: (alg, key, payload) => {