From: Srikanth Yalavarthi
To: Srikanth Yalavarthi
Subject: [PATCH v1 3/3] ml/cnxk: add support for 32 I/O per model
Date: Sat, 22 Apr 2023 22:08:14 -0700
Message-ID: <20230423050814.825-4-syalavarthi@marvell.com>
In-Reply-To: <20230423050814.825-1-syalavarthi@marvell.com>
References: <20230423050814.825-1-syalavarthi@marvell.com>

Added support for up to 32 inputs and 32 outputs per model. Models with
metadata version 2.3.0.1 or later may describe up to 32 inputs and
outputs: the first 8 are taken from the input1/output1 metadata sections
and the remaining 24 from input2/output2. Models with older metadata
remain limited to 8 inputs and 8 outputs.

The metadata version is now checked as a minimum against
MRVL_ML_MODEL_VERSION_MIN using all four version fields, instead of
requiring the major and minor fields to match it exactly.
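The version-dependent I/O limit applied in cn10k_ml_model_metadata_check() can be sketched in isolation. This is a minimal, hypothetical sketch. model_version() and max_io_for_version() are illustrative names, not driver functions. The constants mirror MRVL_ML_NUM_INPUT_OUTPUT_1/_2 and the 2301 threshold used by the patch.

```c
#include <assert.h>
#include <stdint.h>

/* Illustrative copies of the limits defined in cn10k_ml_model.h. */
#define NUM_IO_1 8  /* MRVL_ML_NUM_INPUT_OUTPUT_1: entries in input1/output1 */
#define NUM_IO_2 24 /* MRVL_ML_NUM_INPUT_OUTPUT_2: entries in input2/output2 */
#define NUM_IO (NUM_IO_1 + NUM_IO_2)

static_assert(NUM_IO == 32, "two metadata sections give 32 I/O slots");

/* Hypothetical helper: fold the four header version bytes into one
 * integer the way the patch does (v0 * 1000 + v1 * 100 + v2 * 10 + v3). */
static uint32_t model_version(const uint8_t v[4])
{
	return v[0] * 1000u + v[1] * 100u + v[2] * 10u + v[3];
}

/* Hypothetical helper: maximum number of inputs (or outputs) accepted for
 * a metadata version; 2301 (2.3.0.1) is the threshold the patch uses. */
static unsigned int max_io_for_version(uint32_t version)
{
	return version < 2301 ? NUM_IO_1 : NUM_IO;
}
```

For example, header bytes {2, 3, 0, 1} fold to 2301 and allow 32 inputs, while {2, 1, 0, 0} fold to 2100 and allow 8. The encoding assumes each field is a single decimal digit. A field of 10 or more would collide with another version: 2.10.0.0 and 3.0.0.0 both fold to 3000.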
Signed-off-by: Srikanth Yalavarthi
---
 drivers/ml/cnxk/cn10k_ml_model.c | 374 ++++++++++++++++++++++---------
 drivers/ml/cnxk/cn10k_ml_model.h |   5 +-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 129 ++++++++---
 3 files changed, 369 insertions(+), 139 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index a15df700aa..92c47d39ba 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -41,8 +41,9 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
 	struct cn10k_ml_model_metadata *metadata;
 	uint32_t payload_crc32c;
 	uint32_t header_crc32c;
-	uint8_t version[4];
+	uint32_t version;
 	uint8_t i;
+	uint8_t j;
 
 	metadata = (struct cn10k_ml_model_metadata *)buffer;
 
@@ -82,10 +83,13 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
 	}
 
 	/* Header version */
-	rte_memcpy(version, metadata->header.version, 4 * sizeof(uint8_t));
-	if (version[0] * 1000 + version[1] * 100 != MRVL_ML_MODEL_VERSION_MIN) {
-		plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported", version[0],
-			version[1], version[2], version[3], (MRVL_ML_MODEL_VERSION_MIN / 1000) % 10,
+	version = metadata->header.version[0] * 1000 + metadata->header.version[1] * 100 +
+		  metadata->header.version[2] * 10 + metadata->header.version[3];
+	if (version < MRVL_ML_MODEL_VERSION_MIN) {
+		plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported",
+			metadata->header.version[0], metadata->header.version[1],
+			metadata->header.version[2], metadata->header.version[3],
+			(MRVL_ML_MODEL_VERSION_MIN / 1000) % 10,
 			(MRVL_ML_MODEL_VERSION_MIN / 100) % 10,
 			(MRVL_ML_MODEL_VERSION_MIN / 10) % 10, MRVL_ML_MODEL_VERSION_MIN % 10);
 		return -ENOTSUP;
@@ -125,60 +129,119 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
 	}
 
 	/* Check input count */
-	if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) {
-		plt_err("Invalid metadata, num_input = %u (> %u)", metadata->model.num_input,
-			MRVL_ML_NUM_INPUT_OUTPUT_1);
-		return -EINVAL;
-	}
-
-	/* Check output count */
-	if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) {
-		plt_err("Invalid metadata, num_output = %u (> %u)", metadata->model.num_output,
-			MRVL_ML_NUM_INPUT_OUTPUT_1);
-		return -EINVAL;
-	}
-
-	/* Inputs */
-	for (i = 0; i < metadata->model.num_input; i++) {
-		if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(metadata->input1[i].input_type)) <=
-		    0) {
-			plt_err("Invalid metadata, input[%u] : input_type = %u", i,
-				metadata->input1[i].input_type);
+	if (version < 2301) {
+		if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			plt_err("Invalid metadata, num_input = %u (> %u)",
+				metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT_1);
 			return -EINVAL;
 		}
 
-		if (rte_ml_io_type_size_get(
-			    cn10k_ml_io_type_map(metadata->input1[i].model_input_type)) <= 0) {
-			plt_err("Invalid metadata, input[%u] : model_input_type = %u", i,
-				metadata->input1[i].model_input_type);
+		/* Check output count */
+		if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			plt_err("Invalid metadata, num_output = %u (> %u)",
+				metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT_1);
 			return -EINVAL;
 		}
-
-		if (metadata->input1[i].relocatable != 1) {
-			plt_err("Model not supported, non-relocatable input: %u", i);
-			return -ENOTSUP;
+	} else {
+		if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT) {
+			plt_err("Invalid metadata, num_input = %u (> %u)",
+				metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT);
+			return -EINVAL;
 		}
-	}
 
-	/* Outputs */
-	for (i = 0; i < metadata->model.num_output; i++) {
-		if (rte_ml_io_type_size_get(
-			    cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) {
-			plt_err("Invalid metadata, output[%u] : output_type = %u", i,
-				metadata->output1[i].output_type);
+		/* Check output count */
+		if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT) {
+			plt_err("Invalid metadata, num_output = %u (> %u)",
+				metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT);
 			return -EINVAL;
 		}
+	}
 
-		if (rte_ml_io_type_size_get(
-			    cn10k_ml_io_type_map(metadata->output1[i].model_output_type)) <= 0) {
-			plt_err("Invalid metadata, output[%u] : model_output_type = %u", i,
-				metadata->output1[i].model_output_type);
-			return -EINVAL;
+	/* Inputs */
+	for (i = 0; i < metadata->model.num_input; i++) {
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			if (rte_ml_io_type_size_get(
+				    cn10k_ml_io_type_map(metadata->input1[i].input_type)) <= 0) {
+				plt_err("Invalid metadata, input1[%u] : input_type = %u", i,
+					metadata->input1[i].input_type);
+				return -EINVAL;
+			}
+
+			if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+				    metadata->input1[i].model_input_type)) <= 0) {
+				plt_err("Invalid metadata, input1[%u] : model_input_type = %u", i,
+					metadata->input1[i].model_input_type);
+				return -EINVAL;
+			}
+
+			if (metadata->input1[i].relocatable != 1) {
+				plt_err("Model not supported, non-relocatable input1: %u", i);
+				return -ENOTSUP;
+			}
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			if (rte_ml_io_type_size_get(
+				    cn10k_ml_io_type_map(metadata->input2[j].input_type)) <= 0) {
+				plt_err("Invalid metadata, input2[%u] : input_type = %u", j,
+					metadata->input2[j].input_type);
+				return -EINVAL;
+			}
+
+			if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+				    metadata->input2[j].model_input_type)) <= 0) {
+				plt_err("Invalid metadata, input2[%u] : model_input_type = %u", j,
+					metadata->input2[j].model_input_type);
+				return -EINVAL;
+			}
+
+			if (metadata->input2[j].relocatable != 1) {
+				plt_err("Model not supported, non-relocatable input2: %u", j);
+				return -ENOTSUP;
+			}
 		}
+	}
 
-		if (metadata->output1[i].relocatable != 1) {
-			plt_err("Model not supported, non-relocatable output: %u", i);
-			return -ENOTSUP;
+	/* Outputs */
+	for (i = 0; i < metadata->model.num_output; i++) {
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			if (rte_ml_io_type_size_get(
+				    cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) {
+				plt_err("Invalid metadata, output1[%u] : output_type = %u", i,
+					metadata->output1[i].output_type);
+				return -EINVAL;
+			}
+
+			if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+				    metadata->output1[i].model_output_type)) <= 0) {
+				plt_err("Invalid metadata, output1[%u] : model_output_type = %u", i,
+					metadata->output1[i].model_output_type);
+				return -EINVAL;
+			}
+
+			if (metadata->output1[i].relocatable != 1) {
+				plt_err("Model not supported, non-relocatable output1: %u", i);
+				return -ENOTSUP;
+			}
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			if (rte_ml_io_type_size_get(
+				    cn10k_ml_io_type_map(metadata->output2[j].output_type)) <= 0) {
+				plt_err("Invalid metadata, output2[%u] : output_type = %u", j,
+					metadata->output2[j].output_type);
+				return -EINVAL;
+			}
+
+			if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+				    metadata->output2[j].model_output_type)) <= 0) {
+				plt_err("Invalid metadata, output2[%u] : model_output_type = %u", j,
+					metadata->output2[j].model_output_type);
+				return -EINVAL;
+			}
+
+			if (metadata->output2[j].relocatable != 1) {
+				plt_err("Model not supported, non-relocatable output2: %u", j);
+				return -ENOTSUP;
+			}
 		}
 	}
 
@@ -189,31 +252,60 @@ void
 cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata)
 {
 	uint8_t i;
+	uint8_t j;
 
 	for (i = 0; i < metadata->model.num_input; i++) {
-		metadata->input1[i].input_type =
-			cn10k_ml_io_type_map(metadata->input1[i].input_type);
-		metadata->input1[i].model_input_type =
-			cn10k_ml_io_type_map(metadata->input1[i].model_input_type);
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			metadata->input1[i].input_type =
+				cn10k_ml_io_type_map(metadata->input1[i].input_type);
+			metadata->input1[i].model_input_type =
+				cn10k_ml_io_type_map(metadata->input1[i].model_input_type);
+
+			if (metadata->input1[i].shape.w == 0)
+				metadata->input1[i].shape.w = 1;
+
+			if (metadata->input1[i].shape.x == 0)
+				metadata->input1[i].shape.x = 1;
+
+			if (metadata->input1[i].shape.y == 0)
+				metadata->input1[i].shape.y = 1;
 
-		if (metadata->input1[i].shape.w == 0)
-			metadata->input1[i].shape.w = 1;
+			if (metadata->input1[i].shape.z == 0)
+				metadata->input1[i].shape.z = 1;
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			metadata->input2[j].input_type =
+				cn10k_ml_io_type_map(metadata->input2[j].input_type);
+			metadata->input2[j].model_input_type =
+				cn10k_ml_io_type_map(metadata->input2[j].model_input_type);
 
-		if (metadata->input1[i].shape.x == 0)
-			metadata->input1[i].shape.x = 1;
+			if (metadata->input2[j].shape.w == 0)
+				metadata->input2[j].shape.w = 1;
 
-		if (metadata->input1[i].shape.y == 0)
-			metadata->input1[i].shape.y = 1;
+			if (metadata->input2[j].shape.x == 0)
+				metadata->input2[j].shape.x = 1;
 
-		if (metadata->input1[i].shape.z == 0)
-			metadata->input1[i].shape.z = 1;
+			if (metadata->input2[j].shape.y == 0)
+				metadata->input2[j].shape.y = 1;
+
+			if (metadata->input2[j].shape.z == 0)
+				metadata->input2[j].shape.z = 1;
+		}
 	}
 
 	for (i = 0; i < metadata->model.num_output; i++) {
-		metadata->output1[i].output_type =
-			cn10k_ml_io_type_map(metadata->output1[i].output_type);
-		metadata->output1[i].model_output_type =
-			cn10k_ml_io_type_map(metadata->output1[i].model_output_type);
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			metadata->output1[i].output_type =
+				cn10k_ml_io_type_map(metadata->output1[i].output_type);
+			metadata->output1[i].model_output_type =
+				cn10k_ml_io_type_map(metadata->output1[i].model_output_type);
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			metadata->output2[j].output_type =
+				cn10k_ml_io_type_map(metadata->output2[j].output_type);
+			metadata->output2[j].model_output_type =
+				cn10k_ml_io_type_map(metadata->output2[j].model_output_type);
+		}
 	}
 }
 
@@ -226,6 +318,7 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 	uint8_t *dma_addr_load;
 	uint8_t *dma_addr_run;
 	uint8_t i;
+	uint8_t j;
 	int fpos;
 
 	metadata = &model->metadata;
@@ -272,37 +365,80 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 	addr->total_input_sz_d = 0;
 	addr->total_input_sz_q = 0;
 	for (i = 0; i < metadata->model.num_input; i++) {
-		addr->input[i].nb_elements =
-			metadata->input1[i].shape.w * metadata->input1[i].shape.x *
-			metadata->input1[i].shape.y * metadata->input1[i].shape.z;
-		addr->input[i].sz_d = addr->input[i].nb_elements *
-				      rte_ml_io_type_size_get(metadata->input1[i].input_type);
-		addr->input[i].sz_q = addr->input[i].nb_elements *
-				      rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
-		addr->total_input_sz_d += addr->input[i].sz_d;
-		addr->total_input_sz_q += addr->input[i].sz_q;
-
-		plt_ml_dbg("model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
-			   model->model_id, i, metadata->input1[i].shape.w,
-			   metadata->input1[i].shape.x, metadata->input1[i].shape.y,
-			   metadata->input1[i].shape.z, addr->input[i].sz_d, addr->input[i].sz_q);
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			addr->input[i].nb_elements =
+				metadata->input1[i].shape.w * metadata->input1[i].shape.x *
+				metadata->input1[i].shape.y * metadata->input1[i].shape.z;
+			addr->input[i].sz_d =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input1[i].input_type);
+			addr->input[i].sz_q =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
+			addr->total_input_sz_d += addr->input[i].sz_d;
+			addr->total_input_sz_q += addr->input[i].sz_q;
+
+			plt_ml_dbg(
+				"model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+				model->model_id, i, metadata->input1[i].shape.w,
+				metadata->input1[i].shape.x, metadata->input1[i].shape.y,
+				metadata->input1[i].shape.z, addr->input[i].sz_d,
+				addr->input[i].sz_q);
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			addr->input[i].nb_elements =
+				metadata->input2[j].shape.w * metadata->input2[j].shape.x *
+				metadata->input2[j].shape.y * metadata->input2[j].shape.z;
+			addr->input[i].sz_d =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input2[j].input_type);
+			addr->input[i].sz_q =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
+			addr->total_input_sz_d += addr->input[i].sz_d;
+			addr->total_input_sz_q += addr->input[i].sz_q;
+
+			plt_ml_dbg(
+				"model_id = %u, input2[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+				model->model_id, j, metadata->input2[j].shape.w,
+				metadata->input2[j].shape.x, metadata->input2[j].shape.y,
+				metadata->input2[j].shape.z, addr->input[i].sz_d,
+				addr->input[i].sz_q);
+		}
 	}
 
 	/* Outputs */
 	addr->total_output_sz_q = 0;
 	addr->total_output_sz_d = 0;
 	for (i = 0; i < metadata->model.num_output; i++) {
-		addr->output[i].nb_elements = metadata->output1[i].size;
-		addr->output[i].sz_d = addr->output[i].nb_elements *
-				       rte_ml_io_type_size_get(metadata->output1[i].output_type);
-		addr->output[i].sz_q =
-			addr->output[i].nb_elements *
-			rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
-		addr->total_output_sz_q += addr->output[i].sz_q;
-		addr->total_output_sz_d += addr->output[i].sz_d;
-
-		plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u", model->model_id, i,
-			   addr->output[i].sz_d, addr->output[i].sz_q);
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			addr->output[i].nb_elements = metadata->output1[i].size;
+			addr->output[i].sz_d =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output1[i].output_type);
+			addr->output[i].sz_q =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
+			addr->total_output_sz_q += addr->output[i].sz_q;
+			addr->total_output_sz_d += addr->output[i].sz_d;
+
+			plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u",
+				   model->model_id, i, addr->output[i].sz_d, addr->output[i].sz_q);
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			addr->output[i].nb_elements = metadata->output2[j].size;
+			addr->output[i].sz_d =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output2[j].output_type);
+			addr->output[i].sz_q =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
+			addr->total_output_sz_q += addr->output[i].sz_q;
+			addr->total_output_sz_d += addr->output[i].sz_d;
+
+			plt_ml_dbg("model_id = %u, output2[%u] - sz_d = %u, sz_q = %u",
+				   model->model_id, j, addr->output[i].sz_d, addr->output[i].sz_q);
+		}
 	}
 }
 
@@ -366,6 +502,7 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 	struct rte_ml_io_info *output;
 	struct rte_ml_io_info *input;
 	uint8_t i;
+	uint8_t j;
 
 	metadata = &model->metadata;
 	info = PLT_PTR_CAST(model->info);
@@ -389,26 +526,53 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 
 	/* Set input info */
 	for (i = 0; i < info->nb_inputs; i++) {
-		rte_memcpy(input[i].name, metadata->input1[i].input_name, MRVL_ML_INPUT_NAME_LEN);
-		input[i].dtype = metadata->input1[i].input_type;
-		input[i].qtype = metadata->input1[i].model_input_type;
-		input[i].shape.format = metadata->input1[i].shape.format;
-		input[i].shape.w = metadata->input1[i].shape.w;
-		input[i].shape.x = metadata->input1[i].shape.x;
-		input[i].shape.y = metadata->input1[i].shape.y;
-		input[i].shape.z = metadata->input1[i].shape.z;
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			rte_memcpy(input[i].name, metadata->input1[i].input_name,
+				   MRVL_ML_INPUT_NAME_LEN);
+			input[i].dtype = metadata->input1[i].input_type;
+			input[i].qtype = metadata->input1[i].model_input_type;
+			input[i].shape.format = metadata->input1[i].shape.format;
+			input[i].shape.w = metadata->input1[i].shape.w;
+			input[i].shape.x = metadata->input1[i].shape.x;
+			input[i].shape.y = metadata->input1[i].shape.y;
+			input[i].shape.z = metadata->input1[i].shape.z;
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			rte_memcpy(input[i].name, metadata->input2[j].input_name,
+				   MRVL_ML_INPUT_NAME_LEN);
+			input[i].dtype = metadata->input2[j].input_type;
+			input[i].qtype = metadata->input2[j].model_input_type;
+			input[i].shape.format = metadata->input2[j].shape.format;
+			input[i].shape.w = metadata->input2[j].shape.w;
+			input[i].shape.x = metadata->input2[j].shape.x;
+			input[i].shape.y = metadata->input2[j].shape.y;
+			input[i].shape.z = metadata->input2[j].shape.z;
+		}
 	}
 
 	/* Set output info */
 	for (i = 0; i < info->nb_outputs; i++) {
-		rte_memcpy(output[i].name, metadata->output1[i].output_name,
-			   MRVL_ML_OUTPUT_NAME_LEN);
-		output[i].dtype = metadata->output1[i].output_type;
-		output[i].qtype = metadata->output1[i].model_output_type;
-		output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-		output[i].shape.w = metadata->output1[i].size;
-		output[i].shape.x = 1;
-		output[i].shape.y = 1;
-		output[i].shape.z = 1;
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			rte_memcpy(output[i].name, metadata->output1[i].output_name,
+				   MRVL_ML_OUTPUT_NAME_LEN);
+			output[i].dtype = metadata->output1[i].output_type;
+			output[i].qtype = metadata->output1[i].model_output_type;
+			output[i].shape.format = RTE_ML_IO_FORMAT_1D;
+			output[i].shape.w = metadata->output1[i].size;
+			output[i].shape.x = 1;
+			output[i].shape.y = 1;
+			output[i].shape.z = 1;
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			rte_memcpy(output[i].name, metadata->output2[j].output_name,
+				   MRVL_ML_OUTPUT_NAME_LEN);
+			output[i].dtype = metadata->output2[j].output_type;
+			output[i].qtype = metadata->output2[j].model_output_type;
+			output[i].shape.format = RTE_ML_IO_FORMAT_1D;
+			output[i].shape.w = metadata->output2[j].size;
+			output[i].shape.x = 1;
+			output[i].shape.y = 1;
+			output[i].shape.z = 1;
+		}
 	}
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index bd863a8c12..5c34e4d747 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -30,6 +30,7 @@ enum cn10k_ml_model_state {
 #define MRVL_ML_OUTPUT_NAME_LEN 16
 #define MRVL_ML_NUM_INPUT_OUTPUT_1 8
 #define MRVL_ML_NUM_INPUT_OUTPUT_2 24
+#define MRVL_ML_NUM_INPUT_OUTPUT (MRVL_ML_NUM_INPUT_OUTPUT_1 + MRVL_ML_NUM_INPUT_OUTPUT_2)
 
 /* Header (256-byte) */
 struct cn10k_ml_model_metadata_header {
@@ -413,7 +414,7 @@ struct cn10k_ml_model_addr {
 
 		/* Quantized input size */
 		uint32_t sz_q;
-	} input[MRVL_ML_NUM_INPUT_OUTPUT_1];
+	} input[MRVL_ML_NUM_INPUT_OUTPUT];
 
 	/* Output address and size */
 	struct {
@@ -425,7 +426,7 @@ struct cn10k_ml_model_addr {
 
 		/* Quantized output size */
 		uint32_t sz_q;
-	} output[MRVL_ML_NUM_INPUT_OUTPUT_1];
+	} output[MRVL_ML_NUM_INPUT_OUTPUT];
 
 	/* Total size of quantized input */
 	uint32_t total_input_sz_q;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index aecc6e74ad..1033afb1b0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -269,6 +269,7 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 	struct cn10k_ml_ocm *ocm;
 	char str[STR_LEN];
 	uint8_t i;
+	uint8_t j;
 
 	mldev = dev->data->dev_private;
 	ocm = &mldev->ocm;
@@ -324,16 +325,36 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 		"model_input_type", "quantize", "format");
 	print_line(fp, LINE_LEN);
 	for (i = 0; i < model->metadata.model.num_input; i++) {
-		fprintf(fp, "%8u ", i);
-		fprintf(fp, "%*s ", 16, model->metadata.input1[i].input_name);
-		rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN);
-		fprintf(fp, "%*s ", 12, str);
-		rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str, STR_LEN);
-		fprintf(fp, "%*s ", 18, str);
-		fprintf(fp, "%*s", 12, (model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
-		rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str, STR_LEN);
-		fprintf(fp, "%*s", 16, str);
-		fprintf(fp, "\n");
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			fprintf(fp, "%8u ", i);
+			fprintf(fp, "%*s ", 16, model->metadata.input1[i].input_name);
+			rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN);
+			fprintf(fp, "%*s ", 12, str);
+			rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str,
+					      STR_LEN);
+			fprintf(fp, "%*s ", 18, str);
+			fprintf(fp, "%*s", 12,
+				(model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
+			rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str,
+						STR_LEN);
+			fprintf(fp, "%*s", 16, str);
+			fprintf(fp, "\n");
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			fprintf(fp, "%8u ", i);
+			fprintf(fp, "%*s ", 16, model->metadata.input2[j].input_name);
+			rte_ml_io_type_to_str(model->metadata.input2[j].input_type, str, STR_LEN);
+			fprintf(fp, "%*s ", 12, str);
+			rte_ml_io_type_to_str(model->metadata.input2[j].model_input_type, str,
+					      STR_LEN);
+			fprintf(fp, "%*s ", 18, str);
+			fprintf(fp, "%*s", 12,
+				(model->metadata.input2[j].quantize == 1 ? "Yes" : "No"));
+			rte_ml_io_format_to_str(model->metadata.input2[j].shape.format, str,
+						STR_LEN);
+			fprintf(fp, "%*s", 16, str);
+			fprintf(fp, "\n");
+		}
 	}
 	fprintf(fp, "\n");
 
@@ -342,14 +363,30 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 		"model_output_type", "dequantize");
 	print_line(fp, LINE_LEN);
 	for (i = 0; i < model->metadata.model.num_output; i++) {
-		fprintf(fp, "%8u ", i);
-		fprintf(fp, "%*s ", 16, model->metadata.output1[i].output_name);
-		rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN);
-		fprintf(fp, "%*s ", 12, str);
-		rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str, STR_LEN);
-		fprintf(fp, "%*s ", 18, str);
-		fprintf(fp, "%*s", 12, (model->metadata.output1[i].dequantize == 1 ? "Yes" : "No"));
-		fprintf(fp, "\n");
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			fprintf(fp, "%8u ", i);
+			fprintf(fp, "%*s ", 16, model->metadata.output1[i].output_name);
+			rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN);
+			fprintf(fp, "%*s ", 12, str);
+			rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str,
+					      STR_LEN);
+			fprintf(fp, "%*s ", 18, str);
+			fprintf(fp, "%*s", 12,
+				(model->metadata.output1[i].dequantize == 1 ? "Yes" : "No"));
+			fprintf(fp, "\n");
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			fprintf(fp, "%8u ", i);
+			fprintf(fp, "%*s ", 16, model->metadata.output2[j].output_name);
+			rte_ml_io_type_to_str(model->metadata.output2[j].output_type, str, STR_LEN);
+			fprintf(fp, "%*s ", 12, str);
+			rte_ml_io_type_to_str(model->metadata.output2[j].model_output_type, str,
+					      STR_LEN);
+			fprintf(fp, "%*s ", 18, str);
+			fprintf(fp, "%*s", 12,
+				(model->metadata.output2[j].dequantize == 1 ? "Yes" : "No"));
+			fprintf(fp, "\n");
+		}
 	}
 	fprintf(fp, "\n");
 	print_line(fp, LINE_LEN);
@@ -1863,10 +1900,14 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
 		     void *qbuffer)
 {
 	struct cn10k_ml_model *model;
+	uint8_t model_input_type;
 	uint8_t *lcl_dbuffer;
 	uint8_t *lcl_qbuffer;
+	uint8_t input_type;
 	uint32_t batch_id;
+	float qscale;
 	uint32_t i;
+	uint32_t j;
 	int ret;
 
 	model = dev->data->models[model_id];
@@ -1882,28 +1923,38 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
 
 next_batch:
 	for (i = 0; i < model->metadata.model.num_input; i++) {
-		if (model->metadata.input1[i].input_type ==
-		    model->metadata.input1[i].model_input_type) {
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			input_type = model->metadata.input1[i].input_type;
+			model_input_type = model->metadata.input1[i].model_input_type;
+			qscale = model->metadata.input1[i].qscale;
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			input_type = model->metadata.input2[j].input_type;
+			model_input_type = model->metadata.input2[j].model_input_type;
+			qscale = model->metadata.input2[j].qscale;
+		}
+
+		if (input_type == model_input_type) {
 			rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d);
 		} else {
-			switch (model->metadata.input1[i].model_input_type) {
+			switch (model_input_type) {
 			case RTE_ML_IO_TYPE_INT8:
-				ret = rte_ml_io_float32_to_int8(model->metadata.input1[i].qscale,
+				ret = rte_ml_io_float32_to_int8(qscale,
 								model->addr.input[i].nb_elements,
 								lcl_dbuffer, lcl_qbuffer);
 				break;
 			case RTE_ML_IO_TYPE_UINT8:
-				ret = rte_ml_io_float32_to_uint8(model->metadata.input1[i].qscale,
+				ret = rte_ml_io_float32_to_uint8(qscale,
 								 model->addr.input[i].nb_elements,
 								 lcl_dbuffer, lcl_qbuffer);
 				break;
 			case RTE_ML_IO_TYPE_INT16:
-				ret = rte_ml_io_float32_to_int16(model->metadata.input1[i].qscale,
+				ret = rte_ml_io_float32_to_int16(qscale,
 								 model->addr.input[i].nb_elements,
 								 lcl_dbuffer, lcl_qbuffer);
 				break;
 			case RTE_ML_IO_TYPE_UINT16:
-				ret = rte_ml_io_float32_to_uint16(model->metadata.input1[i].qscale,
+				ret = rte_ml_io_float32_to_uint16(qscale,
 								  model->addr.input[i].nb_elements,
 								  lcl_dbuffer, lcl_qbuffer);
 				break;
@@ -1936,10 +1987,14 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
 		       void *qbuffer, void *dbuffer)
 {
 	struct cn10k_ml_model *model;
+	uint8_t model_output_type;
 	uint8_t *lcl_qbuffer;
 	uint8_t *lcl_dbuffer;
+	uint8_t output_type;
 	uint32_t batch_id;
+	float dscale;
 	uint32_t i;
+	uint32_t j;
 	int ret;
 
 	model = dev->data->models[model_id];
@@ -1955,28 +2010,38 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
 
 next_batch:
 	for (i = 0; i < model->metadata.model.num_output; i++) {
-		if (model->metadata.output1[i].output_type ==
-		    model->metadata.output1[i].model_output_type) {
+		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			output_type = model->metadata.output1[i].output_type;
+			model_output_type = model->metadata.output1[i].model_output_type;
+			dscale = model->metadata.output1[i].dscale;
+		} else {
+			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+			output_type = model->metadata.output2[j].output_type;
+			model_output_type = model->metadata.output2[j].model_output_type;
+			dscale = model->metadata.output2[j].dscale;
+		}
+
+		if (output_type == model_output_type) {
 			rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q);
 		} else {
-			switch (model->metadata.output1[i].model_output_type) {
+			switch (model_output_type) {
 			case RTE_ML_IO_TYPE_INT8:
-				ret = rte_ml_io_int8_to_float32(model->metadata.output1[i].dscale,
+				ret = rte_ml_io_int8_to_float32(dscale,
 								model->addr.output[i].nb_elements,
 								lcl_qbuffer, lcl_dbuffer);
 				break;
 			case RTE_ML_IO_TYPE_UINT8:
-				ret = rte_ml_io_uint8_to_float32(model->metadata.output1[i].dscale,
+				ret = rte_ml_io_uint8_to_float32(dscale,
 								 model->addr.output[i].nb_elements,
 								 lcl_qbuffer, lcl_dbuffer);
 				break;
 			case RTE_ML_IO_TYPE_INT16:
-				ret = rte_ml_io_int16_to_float32(model->metadata.output1[i].dscale,
+				ret = rte_ml_io_int16_to_float32(dscale,
 								 model->addr.output[i].nb_elements,
 								 lcl_qbuffer, lcl_dbuffer);
 				break;
 			case RTE_ML_IO_TYPE_UINT16:
-				ret = rte_ml_io_uint16_to_float32(model->metadata.output1[i].dscale,
+				ret = rte_ml_io_uint16_to_float32(dscale,
 								  model->addr.output[i].nb_elements,
 								  lcl_qbuffer, lcl_dbuffer);
 				break;
-- 
2.17.1
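Every function touched by this patch repeats the same mapping from a logical I/O index i to a metadata entry. If i < MRVL_ML_NUM_INPUT_OUTPUT_1, the entry is input1[i] (or output1[i]); otherwise it is input2[i - MRVL_ML_NUM_INPUT_OUTPUT_1]. Below is a minimal sketch of that mapping written as a single accessor. The structures are simplified, hypothetical stand-ins for the driver's metadata sections. input_desc() is an illustrative name, not an existing driver function.

```c
#include <assert.h>
#include <stdint.h>

#define NUM_IO_1 8  /* mirrors MRVL_ML_NUM_INPUT_OUTPUT_1 */
#define NUM_IO_2 24 /* mirrors MRVL_ML_NUM_INPUT_OUTPUT_2 */
#define NUM_IO (NUM_IO_1 + NUM_IO_2)

/* Hypothetical, simplified input descriptor; the real metadata input
 * sections also carry names, shapes and relocation flags. */
struct io_desc {
	uint8_t input_type;
	uint8_t model_input_type;
	float qscale;
};

/* Hypothetical metadata holding both sections, as split by the patch. */
struct io_metadata {
	struct io_desc input1[NUM_IO_1];
	struct io_desc input2[NUM_IO_2];
};

/* Map a logical input index i (0 .. NUM_IO - 1) to its descriptor:
 * the first NUM_IO_1 come from input1[], the rest from input2[]. */
static const struct io_desc *input_desc(const struct io_metadata *md, unsigned int i)
{
	assert(i < NUM_IO);
	if (i < NUM_IO_1)
		return &md->input1[i];
	return &md->input2[i - NUM_IO_1];
}
```

With such an accessor, a loop like the one in cn10k_ml_io_quantize() could read model_input_type and qscale through input_desc(md, i). No code path would then index input1[] beyond its 8 entries. The trade-off is one extra function call per input in exchange for removing the duplicated if/else branches.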