Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 60 additions & 32 deletions src/wh_client_crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -5474,7 +5474,7 @@ int wh_Client_MlDsaImportKey(whClientContext* ctx, MlDsaKey* key,
{
int ret = WH_ERROR_OK;
whKeyId key_id = WH_KEYID_ERASED;
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
uint16_t buffer_len = 0;

if ((ctx == NULL) || (key == NULL) ||
Expand Down Expand Up @@ -5511,7 +5511,7 @@ int wh_Client_MlDsaExportKey(whClientContext* ctx, whKeyId keyId, MlDsaKey* key,
{
int ret = WH_ERROR_OK;
/* buffer cannot be larger than MTU */
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
uint16_t buffer_len = sizeof(buffer);

if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
Expand Down Expand Up @@ -5666,7 +5666,9 @@ int wh_Client_MlDsaMakeExportKey(whClientContext* ctx, int level, int size,


int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
byte* out, word32* inout_len, MlDsaKey* key)
byte* out, word32* inout_len, MlDsaKey* key,
const byte* context, byte contextLen,
word32 preHashType)
{
int ret = 0;
whMessageCrypto_MlDsaSignRequest* req = NULL;
Expand Down Expand Up @@ -5709,7 +5711,7 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
uint16_t action = WC_ALGO_TYPE_PK;

uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
sizeof(*req) + in_len;
sizeof(*req) + in_len + contextLen;
uint32_t options = 0;

/* Get data pointer from the context to use as request/response storage
Expand All @@ -5726,18 +5728,23 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
ctx->cryptoAffinity);

if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) {
uint8_t* req_hash = (uint8_t*)(req + 1);
uint8_t* req_data = (uint8_t*)(req + 1);
if (evict != 0) {
options |= WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT;
}

memset(req, 0, sizeof(*req));
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->sz = in_len;
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->sz = in_len;
req->contextSz = contextLen;
req->preHashType = preHashType;
if ((in != NULL) && (in_len > 0)) {
memcpy(req_hash, in, in_len);
memcpy(req_data, in, in_len);
}
if ((context != NULL) && (contextLen > 0)) {
memcpy(req_data + in_len, context, contextLen);
}

/* Send Request */
Expand Down Expand Up @@ -5794,8 +5801,9 @@ int wh_Client_MlDsaSign(whClientContext* ctx, const byte* in, word32 in_len,
}

int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
const byte* msg, word32 msg_len, int* out_res,
MlDsaKey* key)
const byte* msg, word32 msg_len, int* out_res,
MlDsaKey* key, const byte* context, byte contextLen,
word32 preHashType)
{
int ret = WH_ERROR_OK;
uint8_t* dataPtr = NULL;
Expand Down Expand Up @@ -5838,7 +5846,7 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
uint32_t options = 0;

uint16_t req_len = sizeof(whMessageCrypto_GenericRequestHeader) +
sizeof(*req) + sig_len + msg_len;
sizeof(*req) + sig_len + msg_len + contextLen;


/* Get data pointer from the context to use as request/response storage
Expand All @@ -5864,17 +5872,22 @@ int wh_Client_MlDsaVerify(whClientContext* ctx, const byte* sig, word32 sig_len,
}

memset(req, 0, sizeof(*req));
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->sigSz = sig_len;
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->sigSz = sig_len;
if ((sig != NULL) && (sig_len > 0)) {
memcpy(req_sig, sig, sig_len);
}
req->hashSz = msg_len;
req->hashSz = msg_len;
if ((msg != NULL) && (msg_len > 0)) {
memcpy(req_hash, msg, msg_len);
}
req->contextSz = contextLen;
req->preHashType = preHashType;
if ((context != NULL) && (contextLen > 0)) {
memcpy(req_hash + msg_len, context, contextLen);
}

/* write request */
ret = wh_Client_SendRequest(ctx, group, action, req_len,
Expand Down Expand Up @@ -5937,7 +5950,7 @@ int wh_Client_MlDsaImportKeyDma(whClientContext* ctx, MlDsaKey* key,
{
int ret = WH_ERROR_OK;
whKeyId key_id = WH_KEYID_ERASED;
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
uint16_t buffer_len = 0;

if ((ctx == NULL) || (key == NULL) ||
Expand Down Expand Up @@ -5969,7 +5982,7 @@ int wh_Client_MlDsaExportKeyDma(whClientContext* ctx, whKeyId keyId,
uint8_t* label)
{
int ret = WH_ERROR_OK;
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE] = {0};
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE] = {0};
uint16_t buffer_len = sizeof(buffer);

if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) {
Expand All @@ -5993,7 +6006,7 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level,
{
int ret = WH_ERROR_OK;
whKeyId key_id = WH_KEYID_ERASED;
byte buffer[DILITHIUM_MAX_PRV_KEY_SIZE];
byte buffer[DILITHIUM_MAX_BOTH_KEY_DER_SIZE];
uint8_t* dataPtr = NULL;
whMessageCrypto_MlDsaKeyGenDmaRequest* req = NULL;
whMessageCrypto_MlDsaKeyGenDmaResponse* res = NULL;
Expand Down Expand Up @@ -6116,7 +6129,9 @@ int wh_Client_MlDsaMakeExportKeyDma(whClientContext* ctx, int level,


int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
byte* out, word32* out_len, MlDsaKey* key)
byte* out, word32* out_len, MlDsaKey* key,
const byte* context, byte contextLen,
word32 preHashType)
{
int ret = 0;
whMessageCrypto_MlDsaSignDmaRequest* req = NULL;
Expand Down Expand Up @@ -6158,7 +6173,8 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
uint16_t action = WC_ALGO_TYPE_PK;

uint16_t req_len =
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
contextLen;
uint32_t options = 0;

/* Get data pointer from the context to use as request/response storage
Expand All @@ -6180,9 +6196,14 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
}

memset(req, 0, sizeof(*req));
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->contextSz = contextLen;
req->preHashType = preHashType;
if ((context != NULL) && (contextLen > 0)) {
memcpy((uint8_t*)(req + 1), context, contextLen);
}

/* Set up DMA buffers */
req->msg.sz = in_len;
Expand Down Expand Up @@ -6255,8 +6276,9 @@ int wh_Client_MlDsaSignDma(whClientContext* ctx, const byte* in, word32 in_len,
}

int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
word32 sig_len, const byte* msg, word32 msg_len,
int* out_res, MlDsaKey* key)
word32 sig_len, const byte* msg, word32 msg_len,
int* out_res, MlDsaKey* key, const byte* context,
byte contextLen, word32 preHashType)
{
int ret = 0;
whMessageCrypto_MlDsaVerifyDmaRequest* req = NULL;
Expand Down Expand Up @@ -6296,7 +6318,8 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
uintptr_t msgAddr = 0;

uint16_t req_len =
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req);
sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req) +
contextLen;

/* Get data pointer from the context to use as request/response storage
*/
Expand All @@ -6317,9 +6340,14 @@ int wh_Client_MlDsaVerifyDma(whClientContext* ctx, const byte* sig,
}

memset(req, 0, sizeof(*req));
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->options = options;
req->level = key->level;
req->keyId = key_id;
req->contextSz = contextLen;
req->preHashType = preHashType;
if ((context != NULL) && (contextLen > 0)) {
memcpy((uint8_t*)(req + 1), context, contextLen);
}

/* Set up DMA buffers */
req->sig.sz = sig_len;
Expand Down
48 changes: 29 additions & 19 deletions src/wh_client_cryptocb.c
Original file line number Diff line number Diff line change
Expand Up @@ -642,12 +642,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
int ret = CRYPTOCB_UNAVAILABLE;

/* Extract info parameters */
const byte* in = info->pk.pqc_sign.in;
word32 in_len = info->pk.pqc_sign.inlen;
byte* out = info->pk.pqc_sign.out;
word32* out_len = info->pk.pqc_sign.outlen;
void* key = info->pk.pqc_sign.key;
int type = info->pk.pqc_sign.type;
const byte* in = info->pk.pqc_sign.in;
word32 in_len = info->pk.pqc_sign.inlen;
byte* out = info->pk.pqc_sign.out;
word32* out_len = info->pk.pqc_sign.outlen;
void* key = info->pk.pqc_sign.key;
int type = info->pk.pqc_sign.type;
const byte* context = info->pk.pqc_sign.context;
byte contextLen = info->pk.pqc_sign.contextLen;
word32 preHashType = info->pk.pqc_sign.preHashType;

#ifndef WOLFHSM_CFG_DMA
if (useDma) {
Expand All @@ -661,13 +664,15 @@ static int _handlePqcSign(whClientContext* ctx, wc_CryptoInfo* info, int useDma)
case WC_PQC_SIG_TYPE_DILITHIUM:
#ifdef WOLFHSM_CFG_DMA
if (useDma) {
ret =
wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len, key);
ret = wh_Client_MlDsaSignDma(ctx, in, in_len, out, out_len,
key, context, contextLen,
preHashType);
}
else
#endif /* WOLFHSM_CFG_DMA */
{
ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key);
ret = wh_Client_MlDsaSign(ctx, in, in_len, out, out_len, key,
context, contextLen, preHashType);
}
break;
#endif /* HAVE_DILITHIUM */
Expand All @@ -688,13 +693,16 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
int ret = CRYPTOCB_UNAVAILABLE;

/* Extract info parameters */
const byte* sig = info->pk.pqc_verify.sig;
word32 sig_len = info->pk.pqc_verify.siglen;
const byte* msg = info->pk.pqc_verify.msg;
word32 msg_len = info->pk.pqc_verify.msglen;
int* res = info->pk.pqc_verify.res;
void* key = info->pk.pqc_verify.key;
int type = info->pk.pqc_verify.type;
const byte* sig = info->pk.pqc_verify.sig;
word32 sig_len = info->pk.pqc_verify.siglen;
const byte* msg = info->pk.pqc_verify.msg;
word32 msg_len = info->pk.pqc_verify.msglen;
int* res = info->pk.pqc_verify.res;
void* key = info->pk.pqc_verify.key;
int type = info->pk.pqc_verify.type;
const byte* context = info->pk.pqc_verify.context;
byte contextLen = info->pk.pqc_verify.contextLen;
word32 preHashType = info->pk.pqc_verify.preHashType;

#ifndef WOLFHSM_CFG_DMA
if (useDma) {
Expand All @@ -709,13 +717,15 @@ static int _handlePqcVerify(whClientContext* ctx, wc_CryptoInfo* info,
#ifdef WOLFHSM_CFG_DMA
if (useDma) {
ret = wh_Client_MlDsaVerifyDma(ctx, sig, sig_len, msg, msg_len,
res, key);
res, key, context, contextLen,
preHashType);
}
else
#endif /* WOLFHSM_CFG_DMA */
{
ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len, res,
key);
ret = wh_Client_MlDsaVerify(ctx, sig, sig_len, msg, msg_len,
res, key, context, contextLen,
preHashType);
}
break;
#endif /* HAVE_DILITHIUM */
Expand Down
8 changes: 8 additions & 0 deletions src/wh_message_crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ int wh_MessageCrypto_TranslateMlDsaSignRequest(
WH_T32(magic, dest, src, level);
WH_T32(magic, dest, src, keyId);
WH_T32(magic, dest, src, sz);
WH_T32(magic, dest, src, contextSz);
WH_T32(magic, dest, src, preHashType);
return 0;
}

Expand Down Expand Up @@ -816,6 +818,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyRequest(
WH_T32(magic, dest, src, keyId);
WH_T32(magic, dest, src, sigSz);
WH_T32(magic, dest, src, hashSz);
WH_T32(magic, dest, src, contextSz);
WH_T32(magic, dest, src, preHashType);
return 0;
}

Expand Down Expand Up @@ -1030,6 +1034,8 @@ int wh_MessageCrypto_TranslateMlDsaSignDmaRequest(
WH_T32(magic, dest, src, options);
WH_T32(magic, dest, src, level);
WH_T32(magic, dest, src, keyId);
WH_T32(magic, dest, src, contextSz);
WH_T32(magic, dest, src, preHashType);

return 0;
}
Expand Down Expand Up @@ -1079,6 +1085,8 @@ int wh_MessageCrypto_TranslateMlDsaVerifyDmaRequest(
WH_T32(magic, dest, src, options);
WH_T32(magic, dest, src, level);
WH_T32(magic, dest, src, keyId);
WH_T32(magic, dest, src, contextSz);
WH_T32(magic, dest, src, preHashType);

return 0;
}
Expand Down
Loading
Loading