diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index 40bab2b3..7f5d2d35 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -5474,7 +5474,7 @@ int wh_Client_MlDsaImportKey(whClientContext* ctx, MlDsaKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE]; + byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE]; uint16_t buffer_len = 0; if ((ctx == NULL) || (key == NULL) || @@ -5511,7 +5511,7 @@ int wh_Client_MlDsaExportKey(whClientContext* ctx, whKeyId keyId, MlDsaKey* key, { int ret = WH_ERROR_OK; /* buffer cannot be larger than MTU */ - byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE]; + byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE]; uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { @@ -5666,7 +5666,9 @@ int wh_Client_MlDsaMakeExportKey(whClientContext* ctx, int level, int size, int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, - byte* out, word32* inout_len, MlDsaKey* key) + byte* out, word32* inout_len, MlDsaKey* key, + const byte* context, byte contextLen, + word32 preHashType) { int ret = 0; whMessageCrypto_MlDsaSignRequest* req = NULL; @@ -5709,7 +5711,7 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, uint16_t action = WC_ALGO_TYPE_PK; uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) + - sizeof(*req) + in_len; + sizeof(*req) + in_len + contextLen; uint32_t options = 0; /* Get data pointer from the context to use as request/response storage @@ -5726,18 +5728,23 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, ctx->cryptoAffinity); if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) { - uint8_t* req_hash = (uint8_t*)(req + 1); + uint8_t* req_data = (uint8_t*)(req + 1); if (evict != 0) { options |= WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT; } memset(req, 0, sizeof(*req)); - req->options = options; - req->level = key->level; - req->keyId = key_id; - req->sz = in_len; + req->options = options; + req->level = key->level; + req->keyId = key_id; + req->sz = in_len; + req->contextSz = contextLen; + req->preHashType = preHashType; if ((in != NULL) && (in_len > 0)) { - memcpy(req_hash, in, in_len); + memcpy(req_data, in, in_len); + } + if ((context != NULL) && (contextLen > 0)) { + memcpy(req_data + in_len, context, contextLen); } /* Send Request */ @@ -5794,8 +5801,9 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, } int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len, - const byte* msg, word32 msg_len, int* out_res, - MlDsaKey* key) + const byte* msg, word32 msg_len, int* out_res, + MlDsaKey* key, const byte* context, byte contextLen, + word32 preHashType) { int ret = WH_ERROR_OK; uint8_t* dataPtr = NULL; @@ -5838,7 +5846,7 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len, uint32_t options = 0; uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) + - sizeof(*req) + sig_len + msg_len; + sizeof(*req) + sig_len + msg_len + contextLen; /* Get data pointer from the context to use as request/response storage @@ -5864,17 +5872,22 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len, } memset(req, 0, sizeof(*req)); - req->options = options; - req->level = key->level; - req->keyId = key_id; - req->sigSz = sig_len; + req->options = options; + req->level = key->level; + req->keyId = key_id; + req->sigSz = sig_len; if ((sig != NULL) && (sig_len > 0)) { memcpy(req_sig, sig, sig_len); } - req->hashSz = msg_len; + req->hashSz = msg_len; if ((msg != NULL) && (msg_len > 0)) { memcpy(req_hash, msg, msg_len); } + req->contextSz = contextLen; + req->preHashType = preHashType; + if ((context != NULL) && (contextLen > 0)) { + memcpy(req_hash + msg_len, context, contextLen); + } /* write request */ ret = wh_Client_SendRequest(ctx, group, action, req_len, @@ -5937,7 +5950,7 @@ int wh_Client_MlDsaImportKeyDma(whClientContext* ctx, MlDsaKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE]; + byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE]; uint16_t buffer_len = 0; if ((ctx == NULL) || (key == NULL) || @@ -5969,7 +5982,7 @@ int wh_Client_MlDsaExportKeyDma(whClientContext* ctx, whKeyId keyId, uint8_t* label) { int ret = WH_ERROR_OK; - byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE] = {0}; + byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE] = {0}; uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { @@ -5993,7 +6006,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE]; + byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE]; uint8_t* dataPtr = NULL; whMessageCrypto_MlDsaKeyGenDmaRequest* req = NULL; whMessageCrypto_MlDsaKeyGenDmaResponse* res = NULL; @@ -6116,7 +6129,9 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level, int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, - byte* out, word32* out_len, MlDsaKey* key) + byte* out, word32* out_len, MlDsaKey* key, + const byte* context, byte contextLen, + word32 preHashType) { int ret = 0; whMessageCrypto_MlDsaSignDmaRequest* req = NULL; @@ -6158,7 +6173,8 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, uint16_t action = WC_ALGO_TYPE_PK; uint16_t req_len = - sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) + + contextLen; uint32_t options = 0; /* Get data pointer from the context to use as request/response storage @@ -6180,9 +6196,14 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, } memset(req, 0, sizeof(*req)); - req->options = options; - req->level = key->level; - req->keyId = key_id; + req->options = options; + req->level = key->level; + req->keyId = key_id; + req->contextSz = contextLen; + req->preHashType = preHashType; + if ((context != NULL) && (contextLen > 0)) { + memcpy((uint8_t*)(req + 1), context, contextLen); + } /* Set up DMA buffers */ req->msg.sz = in_len; @@ -6255,8 +6276,9 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, } int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig, - word32 sig_len, const byte* msg, word32 msg_len, - int* out_res, MlDsaKey* key) + word32 sig_len, const byte* msg, word32 msg_len, + int* out_res, MlDsaKey* key, const byte* context, + byte contextLen, word32 preHashType) { int ret = 0; whMessageCrypto_MlDsaVerifyDmaRequest* req = NULL; @@ -6296,7 +6318,8 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig, uintptr_t msgAddr = 0; uint16_t req_len = - sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) + + contextLen; /* Get data pointer from the context to use as request/response storage */ @@ -6317,9 +6340,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig, } memset(req, 0, sizeof(*req)); - req->options = options; - req->level = key->level; - req->keyId = key_id; + req->options = options; + req->level = key->level; + req->keyId = key_id; + req->contextSz = contextLen; + req->preHashType = preHashType; + if ((context != NULL) && (contextLen > 0)) { + memcpy((uint8_t*)(req + 1), context, contextLen); + } /* Set up DMA buffers */ req->sig.sz = sig_len; diff --git a/src/wh_client_cryptocb.c b/src/wh_client_cryptocb.c index 8037a592..43ff0a32 100644 --- a/src/wh_client_cryptocb.c +++ b/src/wh_client_cryptocb.c @@ -642,12 +642,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma) int ret = CRYPTOCB_UNAVAILABLE; /* Extract info parameters */ - const byte* in = info->pk.pqc_sign.in; - word32 in_len = info->pk.pqc_sign.inlen; - byte* out = info->pk.pqc_sign.out; - word32* out_len = info->pk.pqc_sign.outlen; - void* key = info->pk.pqc_sign.key; - int type = info->pk.pqc_sign.type; + const byte* in = info->pk.pqc_sign.in; + word32 in_len = info->pk.pqc_sign.inlen; + byte* out = info->pk.pqc_sign.out; + word32* out_len = info->pk.pqc_sign.outlen; + void* key = info->pk.pqc_sign.key; + int type = info->pk.pqc_sign.type; + const byte* context = info->pk.pqc_sign.context; + byte contextLen = info->pk.pqc_sign.contextLen; + word32 preHashType = info->pk.pqc_sign.preHashType; #ifndef WOLFHSM_CFG_DMA if (useDma) { @@ -661,13 +664,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma) case WC_PQC_SIG_TYPE_DILITHIUM: #ifdef WOLFHSM_CFG_DMA if (useDma) { - ret = - wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len, key); + ret = wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len, + key, context, contextLen, + preHashType); } else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key); + ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key, + context, contextLen, preHashType); } break; #endif /* HAVE_DILITHIUM */ @@ -688,13 +693,16 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info, int ret = CRYPTOCB_UNAVAILABLE; /* Extract info parameters */ - const byte* sig = info->pk.pqc_verify.sig; - word32 sig_len = info->pk.pqc_verify.siglen; - const byte* msg = info->pk.pqc_verify.msg; - word32 msg_len = info->pk.pqc_verify.msglen; - int* res = info->pk.pqc_verify.res; - void* key = info->pk.pqc_verify.key; - int type = info->pk.pqc_verify.type; + const byte* sig = info->pk.pqc_verify.sig; + word32 sig_len = info->pk.pqc_verify.siglen; + const byte* msg = info->pk.pqc_verify.msg; + word32 msg_len = info->pk.pqc_verify.msglen; + int* res = info->pk.pqc_verify.res; + void* key = info->pk.pqc_verify.key; + int type = info->pk.pqc_verify.type; + const byte* context = info->pk.pqc_verify.context; + byte contextLen = info->pk.pqc_verify.contextLen; + word32 preHashType = info->pk.pqc_verify.preHashType; #ifndef WOLFHSM_CFG_DMA if (useDma) { @@ -709,13 +717,15 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info, #ifdef WOLFHSM_CFG_DMA if (useDma) { ret = wh_Client_MlDsaVerifyDma(ctx, sig, sig_len, msg, msg_len, - res, key); + res, key, context, contextLen, + preHashType); } else #endif /* WOLFHSM_CFG_DMA */ { - ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len, res, - key); + ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len, + res, key, context, contextLen, + preHashType); } break; #endif /* HAVE_DILITHIUM */ diff --git a/src/wh_message_crypto.c b/src/wh_message_crypto.c index 69ef2b87..83289caf 100644 --- a/src/wh_message_crypto.c +++ b/src/wh_message_crypto.c @@ -788,6 +788,8 @@ int wh_MessageCrypto_TranslateMlDsaSignRequest( WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); WH_T32(magic, dest, src, sz); + WH_T32(magic, dest, src, contextSz); + WH_T32(magic, dest, src, preHashType); return 0; } @@ -816,6 +818,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyRequest( WH_T32(magic, dest, src, keyId); WH_T32(magic, dest, src, sigSz); WH_T32(magic, dest, src, hashSz); + WH_T32(magic, dest, src, contextSz); + WH_T32(magic, dest, src, preHashType); return 0; } @@ -1030,6 +1034,8 @@ int wh_MessageCrypto_TranslateMlDsaSignDmaRequest( WH_T32(magic, dest, src, options); WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, contextSz); + WH_T32(magic, dest, src, preHashType); return 0; } @@ -1079,6 +1085,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyDmaRequest( WH_T32(magic, dest, src, options); WH_T32(magic, dest, src, level); WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, contextSz); + WH_T32(magic, dest, src, preHashType); return 0; } diff --git a/src/wh_server_crypto.c b/src/wh_server_crypto.c index 6dfdb06f..6630c6c3 100644 --- a/src/wh_server_crypto.c +++ b/src/wh_server_crypto.c @@ -4159,12 +4159,14 @@ static int _HandleMlDsaSign(whServerContext* ctx, uint16_t magic, int devId, } /* Extract parameters from translated request */ - byte* in = (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_MlDsaSignRequest); - whKeyId key_id = wh_KeyId_TranslateFromClient( + byte* in = (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_MlDsaSignRequest); + whKeyId key_id = wh_KeyId_TranslateFromClient( WH_KEYTYPE_CRYPTO, ctx->comm->client_id, req.keyId); - word32 in_len = req.sz; - uint32_t options = req.options; - int evict = !!(options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); + word32 in_len = req.sz; + uint32_t contextSz = req.contextSz; + uint32_t preHashType = req.preHashType; + uint32_t options = req.options; + int evict = !!(options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); /* Validate key usage policy for signing */ if (!WH_KEYID_ISERASED(key_id)) { @@ -4184,6 +4186,13 @@ static int _HandleMlDsaSign(whServerContext* ctx, uint16_t magic, int devId, if (in_len > available_data) { return WH_ERROR_BADARGS; } + if (contextSz > (available_data - in_len)) { + return WH_ERROR_BADARGS; + } + if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { + return WH_ERROR_BADARGS; + } + byte* req_context = (contextSz > 0) ? (in + in_len) : NULL; /* Response message */ byte* res_out = @@ -4198,9 +4207,17 @@ static int _HandleMlDsaSign(whServerContext* ctx, uint16_t magic, int devId, /* load the private key */ ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); if (ret == WH_ERROR_OK) { - /* sign the input */ - ret = wc_MlDsaKey_Sign(key, res_out, &res_len, in, in_len, - ctx->crypto->rng); + /* sign the input using appropriate FIPS 204 API */ + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_SignCtxHash( + key, req_context, (byte)contextSz, res_out, &res_len, + in, in_len, preHashType, ctx->crypto->rng); + } + else { + ret = wc_MlDsaKey_SignCtx( + key, req_context, (byte)contextSz, res_out, &res_len, + in, in_len, ctx->crypto->rng); + } } wc_MlDsaKey_Free(key); } @@ -4234,8 +4251,6 @@ static int _HandleMlDsaVerify(whServerContext* ctx, uint16_t magic, int devId, (void)outSize; return WH_ERROR_NOHANDLER; #else - (void)inSize; - int ret; MlDsaKey key[1]; whMessageCrypto_MlDsaVerifyRequest req; @@ -4249,11 +4264,13 @@ static int _HandleMlDsaVerify(whServerContext* ctx, uint16_t magic, int devId, } /* Extract parameters from translated request */ - uint32_t options = req.options; - whKeyId key_id = wh_KeyId_TranslateFromClient( + uint32_t options = req.options; + whKeyId key_id = wh_KeyId_TranslateFromClient( WH_KEYTYPE_CRYPTO, ctx->comm->client_id, req.keyId); - uint32_t hash_len = req.hashSz; - uint32_t sig_len = req.sigSz; + uint32_t hash_len = req.hashSz; + uint32_t sig_len = req.sigSz; + uint32_t contextSz = req.contextSz; + uint32_t preHashType = req.preHashType; byte* req_sig = (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_MlDsaVerifyRequest); int evict = !!(options & WH_MESSAGE_CRYPTO_MLDSA_VERIFY_OPTIONS_EVICT); @@ -4276,8 +4293,15 @@ static int _HandleMlDsaVerify(whServerContext* ctx, uint16_t magic, int devId, (sig_len > (available - hash_len))) { return WH_ERROR_BADARGS; } + if (contextSz > (available - sig_len - hash_len)) { + return WH_ERROR_BADARGS; + } + if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { + return WH_ERROR_BADARGS; + } - byte* req_hash = req_sig + sig_len; + byte* req_hash = req_sig + sig_len; + byte* req_context = (contextSz > 0) ? (req_hash + hash_len) : NULL; /* Response message */ int result = 0; @@ -4288,9 +4312,17 @@ static int _HandleMlDsaVerify(whServerContext* ctx, uint16_t magic, int devId, /* load the public key */ ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); if (ret == WH_ERROR_OK) { - /* verify the signature */ - ret = wc_MlDsaKey_Verify(key, req_sig, sig_len, req_hash, hash_len, - &result); + /* verify the signature using appropriate FIPS 204 API */ + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_VerifyCtxHash( + key, req_sig, sig_len, req_context, (byte)contextSz, + req_hash, hash_len, preHashType, &result); + } + else { + ret = wc_MlDsaKey_VerifyCtx( + key, req_sig, sig_len, req_context, (byte)contextSz, + req_hash, hash_len, &result); + } } wc_MlDsaKey_Free(key); } @@ -5381,6 +5413,21 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, ctx->comm->client_id, req.keyId); evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); + /* Extract context from inline data after the struct */ + uint32_t contextSz = req.contextSz; + uint32_t preHashType = req.preHashType; + byte* req_context = NULL; + if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { + return WH_ERROR_BADARGS; + } + if (contextSz > 0) { + if (inSize < sizeof(whMessageCrypto_MlDsaSignDmaRequest) + contextSz) { + return WH_ERROR_BADARGS; + } + req_context = (uint8_t*)(cryptoDataIn) + + sizeof(whMessageCrypto_MlDsaSignDmaRequest); + } + /* Initialize key */ ret = wc_MlDsaKey_Init(key, NULL, devId); if (ret == 0) { @@ -5400,10 +5447,20 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); if (ret == 0) { - /* Sign the message */ + /* Sign the message using appropriate FIPS 204 API */ word32 sigLen = req.sig.sz; - ret = wc_MlDsaKey_Sign(key, sigAddr, &sigLen, msgAddr, - req.msg.sz, ctx->crypto->rng); + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_SignCtxHash( + key, req_context, (byte)contextSz, + sigAddr, &sigLen, msgAddr, req.msg.sz, + preHashType, ctx->crypto->rng); + } + else { + ret = wc_MlDsaKey_SignCtx( + key, req_context, (byte)contextSz, + sigAddr, &sigLen, msgAddr, req.msg.sz, + ctx->crypto->rng); + } if (ret == 0) { /* Post-write processing of signature buffer */ @@ -5493,6 +5550,21 @@ static int _HandleMlDsaVerifyDma(whServerContext* ctx, uint16_t magic, ctx->comm->client_id, req.keyId); evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_VERIFY_OPTIONS_EVICT); + /* Extract context from inline data after the struct */ + uint32_t contextSz = req.contextSz; + uint32_t preHashType = req.preHashType; + byte* req_context = NULL; + if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { + return WH_ERROR_BADARGS; + } + if (contextSz > 0) { + if (inSize < sizeof(whMessageCrypto_MlDsaVerifyDmaRequest) + contextSz) { + return WH_ERROR_BADARGS; + } + req_context = (uint8_t*)(cryptoDataIn) + + sizeof(whMessageCrypto_MlDsaVerifyDmaRequest); + } + /* Initialize key */ ret = wc_MlDsaKey_Init(key, NULL, devId); if (ret != 0) { @@ -5514,9 +5586,17 @@ static int _HandleMlDsaVerifyDma(whServerContext* ctx, uint16_t magic, WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); if (ret == 0) { - /* Verify the signature */ - ret = wc_MlDsaKey_Verify(key, sigAddr, req.sig.sz, msgAddr, - req.msg.sz, &verified); + /* Verify the signature using appropriate FIPS 204 API */ + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_VerifyCtxHash( + key, sigAddr, req.sig.sz, req_context, (byte)contextSz, + msgAddr, req.msg.sz, preHashType, &verified); + } + else { + ret = wc_MlDsaKey_VerifyCtx( + key, sigAddr, req.sig.sz, req_context, (byte)contextSz, + msgAddr, req.msg.sz, &verified); + } if (ret == 0) { /* Post-read processing of signature buffer */ diff --git a/test/config/wolfhsm_cfg.h b/test/config/wolfhsm_cfg.h index d0519a57..7e5a68f6 100644 --- a/test/config/wolfhsm_cfg.h +++ b/test/config/wolfhsm_cfg.h @@ -33,7 +33,7 @@ /* #define WOLFHSM_CFG_NO_CRYPTO */ /* #define WOLFHSM_CFG_SHE_EXTENSION */ -#define WOLFHSM_CFG_COMM_DATA_LEN (1280 * 4) +#define WOLFHSM_CFG_COMM_DATA_LEN (1024 * 8) /* Enable global keys feature for testing */ #define WOLFHSM_CFG_GLOBAL_KEYS diff --git a/test/wh_test_crypto.c b/test/wh_test_crypto.c index e72ae6a4..8709eed6 100644 --- a/test/wh_test_crypto.c +++ b/test/wh_test_crypto.c @@ -4270,6 +4270,128 @@ static int whTestCrypto_MlDsaWolfCrypt(whClientContext* ctx, int devId, return ret; } +static int whTestCrypto_MlDsaClient(whClientContext* ctx, int devId, + WC_RNG* rng) +{ + (void)devId; + (void)rng; + + int ret = 0; + MlDsaKey key[1]; + + /* Initialize key */ + ret = wc_MlDsaKey_Init(key, NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed to initialize ML-DSA key: %d\n", ret); + return ret; + } + + /* Generate ephemeral key using non-DMA client API */ + if (ret == 0) { + ret = wh_Client_MlDsaMakeExportKey(ctx, WC_ML_DSA_44, 0, key); + if (ret != 0) { + WH_ERROR_PRINT("Failed to generate ML-DSA key: %d\n", ret); + } + } + + /* Test basic sign/verify using the public API (no context) */ + if (ret == 0) { + byte msg[] = "Test message for non-DMA ML-DSA"; + byte sig[DILITHIUM_MAX_SIG_SIZE]; + word32 sigLen = sizeof(sig); + int verified = 0; + + ret = wh_Client_MlDsaSign(ctx, msg, sizeof(msg), sig, &sigLen, key, + NULL, 0, WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to sign using ML-DSA non-DMA: %d\n", ret); + } + else { + ret = wh_Client_MlDsaVerify(ctx, sig, sigLen, msg, sizeof(msg), + &verified, key, NULL, 0, + WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to verify ML-DSA non-DMA: %d\n", ret); + } + else if (!verified) { + WH_ERROR_PRINT("ML-DSA non-DMA verification failed\n"); + ret = -1; + } + else { + /* Modify signature - should fail verification */ + sig[0] ^= 0xFF; + int vret = wh_Client_MlDsaVerify(ctx, sig, sigLen, msg, + sizeof(msg), &verified, key, + NULL, 0, WC_HASH_TYPE_NONE); + if (vret != 0) { + WH_ERROR_PRINT("Failed to call verify with modified sig: " + "%d\n", vret); + ret = vret; + } + else if (verified) { + WH_ERROR_PRINT("ML-DSA non-DMA verified bad signature\n"); + ret = -1; + } + } + } + } + + /* Test sign/verify with FIPS 204 context string */ + if (ret == 0) { + byte msg[] = "Context test message non-DMA"; + byte sig[DILITHIUM_MAX_SIG_SIZE]; + word32 sigLen = sizeof(sig); + int verified = 0; + const byte ctx_str[] = "test-context"; + + ret = wh_Client_MlDsaSign(ctx, msg, sizeof(msg), sig, &sigLen, key, + ctx_str, sizeof(ctx_str), + WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to sign with context non-DMA: %d\n", ret); + } + else { + /* Verify with same context - should succeed */ + ret = wh_Client_MlDsaVerify(ctx, sig, sigLen, msg, sizeof(msg), + &verified, key, ctx_str, + sizeof(ctx_str), WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to verify with context non-DMA: %d\n", + ret); + } + else if (!verified) { + WH_ERROR_PRINT("Context verification failed non-DMA\n"); + ret = -1; + } + else { + /* Verify with wrong context - should fail verification */ + const byte wrong_ctx[] = "wrong-context"; + int wrong_verified = 0; + int vret = wh_Client_MlDsaVerify( + ctx, sig, sigLen, msg, sizeof(msg), &wrong_verified, key, + wrong_ctx, sizeof(wrong_ctx), WC_HASH_TYPE_NONE); + if (vret != 0) { + WH_ERROR_PRINT("Failed to call verify with wrong context " + "non-DMA: %d\n", vret); + ret = vret; + } + else if (wrong_verified) { + WH_ERROR_PRINT("Verification succeeded with wrong context " + "non-DMA\n"); + ret = -1; + } + } + } + } + + if (ret == 0) { + WH_TEST_PRINT("ML-DSA Client Non-DMA API SUCCESS\n"); + } + + wc_MlDsaKey_Free(key); + return ret; +} + #ifdef WOLFHSM_CFG_DMA static int whTestCrypto_MlDsaDmaClient(whClientContext* ctx, int devId, WC_RNG* rng) @@ -4372,14 +4494,16 @@ static int whTestCrypto_MlDsaDmaClient(whClientContext* ctx, int devId, int verified = 0; /* Sign the message */ - ret = wh_Client_MlDsaSignDma(ctx, msg, sizeof(msg), sig, &sigLen, key); + ret = wh_Client_MlDsaSignDma(ctx, msg, sizeof(msg), sig, &sigLen, key, + NULL, 0, WC_HASH_TYPE_NONE); if (ret != 0) { WH_ERROR_PRINT("Failed to sign message using ML-DSA: %d\n", ret); } else { /* Verify the signature - should succeed */ ret = wh_Client_MlDsaVerifyDma(ctx, sig, sigLen, msg, sizeof(msg), - &verified, key); + &verified, key, NULL, 0, + WC_HASH_TYPE_NONE); if (ret != 0) { WH_ERROR_PRINT("Failed to verify signature using ML-DSA: %d\n", ret); @@ -4393,7 +4517,8 @@ static int whTestCrypto_MlDsaDmaClient(whClientContext* ctx, int devId, /* Modify signature and verify again - should fail */ sig[0] ^= 0xFF; ret = wh_Client_MlDsaVerifyDma(ctx, sig, sigLen, msg, - sizeof(msg), &verified, key); + sizeof(msg), &verified, key, + NULL, 0, WC_HASH_TYPE_NONE); if (ret != 0) { WH_ERROR_PRINT("Failed to verify modified signature using " "ML-DSA: %d\n", @@ -4412,6 +4537,58 @@ static int whTestCrypto_MlDsaDmaClient(whClientContext* ctx, int devId, } } + /* Test signing and verification with a FIPS 204 context string */ + if (ret == 0) { + byte msg[] = "Context test message"; + byte sig[DILITHIUM_MAX_SIG_SIZE]; + word32 sigLen = sizeof(sig); + int verified = 0; + const byte ctx_str[] = "test-context"; + + /* Sign with a context string */ + ret = wh_Client_MlDsaSignDma(ctx, msg, sizeof(msg), sig, &sigLen, + key, ctx_str, sizeof(ctx_str), + WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to sign with context using ML-DSA: %d\n", + ret); + } + else { + /* Verify with the same context - should succeed */ + ret = wh_Client_MlDsaVerifyDma(ctx, sig, sigLen, msg, + sizeof(msg), &verified, key, + ctx_str, sizeof(ctx_str), + WC_HASH_TYPE_NONE); + if (ret != 0) { + WH_ERROR_PRINT("Failed to verify with context using ML-DSA: " + "%d\n", ret); + } + else if (!verified) { + WH_ERROR_PRINT("Context verification failed when it should " + "have succeeded\n"); + ret = -1; + } + else { + /* Verify with wrong context - should fail verification */ + const byte wrong_ctx[] = "wrong-context"; + int wrong_verified = 0; + int vret = wh_Client_MlDsaVerifyDma( + ctx, sig, sigLen, msg, sizeof(msg), &wrong_verified, key, + wrong_ctx, sizeof(wrong_ctx), WC_HASH_TYPE_NONE); + if (vret != 0) { + WH_ERROR_PRINT("Failed to call verify with wrong context: " + "%d\n", vret); + ret = vret; + } + else if (wrong_verified) { + WH_ERROR_PRINT("Context verification succeeded with wrong " + "context\n"); + ret = -1; + } + } + } + } + /* Clean up the cached key if it was imported */ if (keyImported) { int evict_ret = wh_Client_KeyEvict(ctx, keyId); @@ -5767,6 +5944,10 @@ int whTest_CryptoClientConfig(whClientConfig* config) } } + if (ret == 0) { + ret = whTestCrypto_MlDsaClient(client, WH_DEV_ID, rng); + } + #ifdef WOLFHSM_CFG_DMA if (ret == 0) { ret = whTestCrypto_MlDsaDmaClient(client, WH_DEV_ID_DMA, rng); diff --git a/test/wh_test_wolfcrypt_test.c b/test/wh_test_wolfcrypt_test.c index f1db0ca7..73947248 100644 --- a/test/wh_test_wolfcrypt_test.c +++ b/test/wh_test_wolfcrypt_test.c @@ -70,7 +70,7 @@ #endif -#define BUFFER_SIZE 4096 +#define BUFFER_SIZE (1024 * 8) #define FLASH_RAM_SIZE (1024 * 1024) /* 1MB */ #if defined(WOLFHSM_CFG_ENABLE_CLIENT) && !defined(NO_CRYPT_TEST) diff --git a/wolfhsm/wh_client_crypto.h b/wolfhsm/wh_client_crypto.h index 1cd02b3c..b09514a2 100644 --- a/wolfhsm/wh_client_crypto.h +++ b/wolfhsm/wh_client_crypto.h @@ -1124,11 +1124,17 @@ int wh_Client_MlDsaMakeCacheKey(whClientContext* ctx, int size, int level, * @param[in,out] out_len Pointer to size of output buffer, updated with actual * size. * @param[in] key Pointer to the ML-DSA key structure. + * @param[in] context Optional FIPS 204 context string for domain separation, + * or NULL for no context. + * @param[in] contextLen Length of the context string (max 255). + * @param[in] preHashType Hash type for HashML-DSA (e.g. WC_HASH_TYPE_SHA256), + * or WC_HASH_TYPE_NONE for pure ML-DSA. * @return int Returns 0 on success, or a negative error code on failure. */ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, - byte* out, word32* out_len, MlDsaKey* key); - + byte* out, word32* out_len, MlDsaKey* key, + const byte* context, byte contextLen, + word32 preHashType); /** * @brief Verify a ML-DSA signature. * @@ -1141,11 +1147,17 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len, * @param[in] msg_len Length of the message in bytes. * @param[out] res Pointer to store verification result (1=success, 0=failure). * @param[in] key Pointer to the ML-DSA key structure. + * @param[in] context Optional FIPS 204 context string for domain separation, + * or NULL for no context. + * @param[in] contextLen Length of the context string (max 255). + * @param[in] preHashType Hash type for HashML-DSA (e.g. WC_HASH_TYPE_SHA256), + * or WC_HASH_TYPE_NONE for pure ML-DSA. * @return int Returns 0 on success, or a negative error code on failure. */ -int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len, - const byte* msg, word32 msg_len, int* res, - MlDsaKey* key); +int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, + word32 sig_len, const byte* msg, word32 msg_len, + int* res, MlDsaKey* key, const byte* context, + byte contextLen, word32 preHashType); /** * @brief Check a ML-DSA private key. @@ -1225,10 +1237,17 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level, * @param[in,out] out_len On input, size of out buffer. On output, length of * signature. * @param[in] key Pointer to the ML-DSA key structure. + * @param[in] context Optional FIPS 204 context string for domain separation, + * or NULL for no context. + * @param[in] contextLen Length of the context string (max 255). + * @param[in] preHashType Hash type for HashML-DSA (e.g. WC_HASH_TYPE_SHA256), + * or WC_HASH_TYPE_NONE for pure ML-DSA. * @return int Returns 0 on success, or a negative error code on failure. */ -int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, - byte* out, word32* out_len, MlDsaKey* key); +int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, + word32 in_len, byte* out, word32* out_len, + MlDsaKey* key, const byte* context, + byte contextLen, word32 preHashType); /** * @brief Verify a ML-DSA signature with DMA. @@ -1242,11 +1261,18 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len, * @param[in] msg_len Length of the message in bytes. * @param[out] res Result of verification (1 = success, 0 = failure). * @param[in] key Pointer to the ML-DSA key structure. + * @param[in] context Optional FIPS 204 context string for domain separation, + * or NULL for no context. + * @param[in] contextLen Length of the context string (max 255). + * @param[in] preHashType Hash type for HashML-DSA (e.g. WC_HASH_TYPE_SHA256), + * or WC_HASH_TYPE_NONE for pure ML-DSA. * @return int Returns 0 on success, or a negative error code on failure. */ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig, - word32 sig_len, const byte* msg, word32 msg_len, - int* res, MlDsaKey* key); + word32 sig_len, const byte* msg, + word32 msg_len, int* res, MlDsaKey* key, + const byte* context, byte contextLen, + word32 preHashType); /** * @brief Check a ML-DSA private key against public key with DMA. diff --git a/wolfhsm/wh_crypto.h b/wolfhsm/wh_crypto.h index 8ce20fd1..7254d783 100644 --- a/wolfhsm/wh_crypto.h +++ b/wolfhsm/wh_crypto.h @@ -109,6 +109,7 @@ int wh_Crypto_Ed25519DeserializeKeyDer(const uint8_t* buffer, uint16_t size, #endif /* HAVE_ED25519 */ #ifdef HAVE_DILITHIUM +#define WH_CRYPTO_MLDSA_MAX_CTX_LEN (255U) /* Store a MlDsaKey to a byte sequence */ int wh_Crypto_MlDsaSerializeKeyDer(MlDsaKey* key, uint16_t max_size, uint8_t* buffer, uint16_t* out_size); diff --git a/wolfhsm/wh_message_crypto.h b/wolfhsm/wh_message_crypto.h index 475aa371..6f33d89c 100644 --- a/wolfhsm/wh_message_crypto.h +++ b/wolfhsm/wh_message_crypto.h @@ -889,9 +889,11 @@ typedef struct { uint32_t level; uint32_t keyId; uint32_t sz; - uint8_t WH_PAD[4]; + uint32_t contextSz; /* FIPS 204 context length (0-255) */ + uint32_t preHashType; /* enum wc_HashType, 0 for pure ML-DSA */ /* Data follows: * uint8_t in[sz]; + * uint8_t context[contextSz]; */ } whMessageCrypto_MlDsaSignRequest; @@ -921,10 +923,13 @@ typedef struct { uint32_t keyId; uint32_t sigSz; uint32_t hashSz; + uint32_t contextSz; /* FIPS 204 context length (0-255) */ + uint32_t preHashType; /* enum wc_HashType, 0 for pure ML-DSA */ uint8_t WH_PAD[4]; /* Data follows: * uint8_t sig[sigSz]; * uint8_t hash[hashSz]; + * uint8_t context[contextSz]; */ } whMessageCrypto_MlDsaVerifyRequest; @@ -1165,12 +1170,17 @@ typedef struct { /* ML-DSA DMA Sign Request */ typedef struct { - whMessageCrypto_DmaBuffer msg; /* Message buffer */ - whMessageCrypto_DmaBuffer sig; /* Signature buffer */ - uint32_t options; /* Same options as non-DMA version */ - uint32_t level; /* ML-DSA security level */ - uint32_t keyId; /* Key ID to use for signing */ - uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ + whMessageCrypto_DmaBuffer msg; /* Message buffer */ + whMessageCrypto_DmaBuffer sig; /* Signature buffer */ + uint32_t options; /* Same options as non-DMA version */ + uint32_t level; /* ML-DSA security level */ + uint32_t keyId; /* Key ID to use for signing */ + uint32_t contextSz; /* FIPS 204 context length (0-255) */ + uint32_t preHashType; /* enum wc_HashType */ + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ + /* Data follows: + * uint8_t context[contextSz]; + */ } whMessageCrypto_MlDsaSignDmaRequest; /* ML-DSA DMA Sign Response */ @@ -1182,12 +1192,17 @@ typedef struct { /* ML-DSA DMA Verify Request */ typedef struct { - whMessageCrypto_DmaBuffer sig; /* Signature buffer */ - whMessageCrypto_DmaBuffer msg; /* Message buffer */ - uint32_t options; /* Same options as non-DMA version */ - uint32_t level; /* ML-DSA security level */ - uint32_t keyId; /* Key ID to use for verification */ - uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ + whMessageCrypto_DmaBuffer sig; /* Signature buffer */ + whMessageCrypto_DmaBuffer msg; /* Message buffer */ + uint32_t options; /* Same options as non-DMA version */ + uint32_t level; /* ML-DSA security level */ + uint32_t keyId; /* Key ID to use for verification */ + uint32_t contextSz; /* FIPS 204 context length (0-255) */ + uint32_t preHashType; /* enum wc_HashType */ + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ + /* Data follows: + * uint8_t context[contextSz]; + */ } whMessageCrypto_MlDsaVerifyDmaRequest; /* ML-DSA DMA Verify Response */