From mboxrd@z Thu Jan 1 00:00:00 1970
From: Srikanth Yalavarthi
To: Srikanth Yalavarthi
Subject: [PATCH v4 12/39] ml/cnxk: enable validity checks for model metadata
Date: Wed, 1 Feb 2023 01:22:43 -0800
Message-ID: <20230201092310.23252-13-syalavarthi@marvell.com>
X-Mailer: git-send-email 2.17.1
In-Reply-To: <20230201092310.23252-1-syalavarthi@marvell.com>
References: <20221208200220.20267-1-syalavarthi@marvell.com> <20230201092310.23252-1-syalavarthi@marvell.com>
MIME-Version: 1.0
Content-Type: text/plain
X-BeenThere: dev@dpdk.org
List-Id: DPDK patches and discussions

Add the model metadata structure and enable the metadata check during
model load. Remap cnxk IO types to RTE IO types. Store and update the
model metadata in the model structure.

Signed-off-by: Srikanth Yalavarthi
---
 drivers/ml/cnxk/cn10k_ml_model.c | 211 +++++++++++++++++++++
 drivers/ml/cnxk/cn10k_ml_model.h | 312 +++++++++++++++++++++++++++++++
 drivers/ml/cnxk/cn10k_ml_ops.c   |  14 +-
 drivers/ml/cnxk/meson.build      |   2 +-
 4 files changed, 537 insertions(+), 2 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 39ed707396..0fefab9daa 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -2,4 +2,215 @@
  * Copyright (c) 2022 Marvell.
  */
 
+#include
+
+#include
+
+#include "cn10k_ml_dev.h"
 #include "cn10k_ml_model.h"
+
+static enum rte_ml_io_type
+cn10k_ml_io_type_map(uint8_t type)
+{
+	switch (type) {
+	case 1:
+		return RTE_ML_IO_TYPE_INT8;
+	case 2:
+		return RTE_ML_IO_TYPE_UINT8;
+	case 3:
+		return RTE_ML_IO_TYPE_INT16;
+	case 4:
+		return RTE_ML_IO_TYPE_UINT16;
+	case 5:
+		return RTE_ML_IO_TYPE_INT32;
+	case 6:
+		return RTE_ML_IO_TYPE_UINT32;
+	case 7:
+		return RTE_ML_IO_TYPE_FP16;
+	case 8:
+		return RTE_ML_IO_TYPE_FP32;
+	}
+
+	return RTE_ML_IO_TYPE_UNKNOWN;
+}
+
+int
+cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
+{
+	struct cn10k_ml_model_metadata *metadata;
+	uint32_t payload_crc32c;
+	uint32_t header_crc32c;
+	uint8_t version[4];
+	uint8_t i;
+
+	metadata = (struct cn10k_ml_model_metadata *)buffer;
+
+	/* Header CRC check */
+	if (metadata->metadata_header.header_crc32c != 0) {
+		header_crc32c = rte_hash_crc(
+			buffer, sizeof(metadata->metadata_header) - sizeof(uint32_t), 0);
+
+		if (header_crc32c != metadata->metadata_header.header_crc32c) {
+			plt_err("Invalid model, Header CRC mismatch");
+			return -EINVAL;
+		}
+	}
+
+	/* Payload CRC check */
+	if (metadata->metadata_header.payload_crc32c != 0) {
+		payload_crc32c = rte_hash_crc(buffer + sizeof(metadata->metadata_header),
+					      size - sizeof(metadata->metadata_header), 0);
+
+		if (payload_crc32c != metadata->metadata_header.payload_crc32c) {
+			plt_err("Invalid model, Payload CRC mismatch");
+			return -EINVAL;
+		}
+	}
+
+	/* Model magic string */
+	if (strncmp((char *)metadata->metadata_header.magic, MRVL_ML_MODEL_MAGIC_STRING, 4) != 0) {
+		plt_err("Invalid model, magic = %s", metadata->metadata_header.magic);
+		return -EINVAL;
+	}
+
+	/* Target architecture */
+	if (metadata->metadata_header.target_architecture != MRVL_ML_MODEL_TARGET_ARCH) {
+		plt_err("Model target architecture (%u) not supported",
+			metadata->metadata_header.target_architecture);
+		return -ENOTSUP;
+	}
+
+	/* Header version */
+	rte_memcpy(version, metadata->metadata_header.version, 4 * sizeof(uint8_t));
+	if (version[0] * 1000 + version[1] * 100 < MRVL_ML_MODEL_VERSION) {
+		plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported", version[0],
+			version[1], version[2], version[3], (MRVL_ML_MODEL_VERSION / 1000) % 10,
+			(MRVL_ML_MODEL_VERSION / 100) % 10, (MRVL_ML_MODEL_VERSION / 10) % 10,
+			MRVL_ML_MODEL_VERSION % 10);
+		return -ENOTSUP;
+	}
+
+	/* Init section */
+	if (metadata->init_model.file_size == 0) {
+		plt_err("Invalid metadata, init_model.file_size = %u",
+			metadata->init_model.file_size);
+		return -EINVAL;
+	}
+
+	/* Main section */
+	if (metadata->main_model.file_size == 0) {
+		plt_err("Invalid metadata, main_model.file_size = %u",
+			metadata->main_model.file_size);
+		return -EINVAL;
+	}
+
+	/* Finish section */
+	if (metadata->finish_model.file_size == 0) {
+		plt_err("Invalid metadata, finish_model.file_size = %u",
+			metadata->finish_model.file_size);
+		return -EINVAL;
+	}
+
+	/* Weights and Bias */
+	if (metadata->weights_bias.file_size == 0) {
+		plt_err("Invalid metadata, weights_bias.file_size = %u",
+			metadata->weights_bias.file_size);
+		return -EINVAL;
+	}
+
+	if (metadata->weights_bias.relocatable != 1) {
+		plt_err("Model not supported, non-relocatable weights and bias");
+		return -ENOTSUP;
+	}
+
+	/* Check input count */
+	if (metadata->model.num_input > MRVL_ML_INPUT_OUTPUT_SIZE) {
+		plt_err("Invalid metadata, num_input = %d (> %d)", metadata->model.num_input,
+			MRVL_ML_INPUT_OUTPUT_SIZE);
+		return -EINVAL;
+	}
+
+	/* Check output count */
+	if (metadata->model.num_output > MRVL_ML_INPUT_OUTPUT_SIZE) {
+		plt_err("Invalid metadata, num_output = %d (> %d)", metadata->model.num_output,
+			MRVL_ML_INPUT_OUTPUT_SIZE);
+		return -EINVAL;
+	}
+
+	/* Inputs */
+	for (i = 0; i < metadata->model.num_input; i++) {
+		if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(metadata->input[i].input_type)) <=
+		    0) {
+			plt_err("Invalid metadata, input[%u] : input_type = %u", i,
+				metadata->input[i].input_type);
+			return -EINVAL;
+		}
+
+		if (rte_ml_io_type_size_get(
+			    cn10k_ml_io_type_map(metadata->input[i].model_input_type)) <= 0) {
+			plt_err("Invalid metadata, input[%u] : model_input_type = %u", i,
+				metadata->input[i].model_input_type);
+			return -EINVAL;
+		}
+
+		if (metadata->input[i].relocatable != 1) {
+			plt_err("Model not supported, non-relocatable input: %u", i);
+			return -ENOTSUP;
+		}
+	}
+
+	/* Outputs */
+	for (i = 0; i < metadata->model.num_output; i++) {
+		if (rte_ml_io_type_size_get(
+			    cn10k_ml_io_type_map(metadata->output[i].output_type)) <= 0) {
+			plt_err("Invalid metadata, output[%u] : output_type = %u", i,
+				metadata->output[i].output_type);
+			return -EINVAL;
+		}
+
+		if (rte_ml_io_type_size_get(
+			    cn10k_ml_io_type_map(metadata->output[i].model_output_type)) <= 0) {
+			plt_err("Invalid metadata, output[%u] : model_output_type = %u", i,
+				metadata->output[i].model_output_type);
+			return -EINVAL;
+		}
+
+		if (metadata->output[i].relocatable != 1) {
+			plt_err("Model not supported, non-relocatable output: %u", i);
+			return -ENOTSUP;
+		}
+	}
+
+	return 0;
+}
+
+void
+cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata)
+{
+	uint8_t i;
+
+	for (i = 0; i < metadata->model.num_input; i++) {
+		metadata->input[i].input_type = cn10k_ml_io_type_map(metadata->input[i].input_type);
+		metadata->input[i].model_input_type =
+			cn10k_ml_io_type_map(metadata->input[i].model_input_type);
+
+		if (metadata->input[i].shape.w == 0)
+			metadata->input[i].shape.w = 1;
+
+		if (metadata->input[i].shape.x == 0)
+			metadata->input[i].shape.x = 1;
+
+		if (metadata->input[i].shape.y == 0)
+			metadata->input[i].shape.y = 1;
+
+		if (metadata->input[i].shape.z == 0)
+			metadata->input[i].shape.z = 1;
+	}
+
+	for (i = 0; i < metadata->model.num_output; i++) {
+		metadata->output[i].output_type =
+			cn10k_ml_io_type_map(metadata->output[i].output_type);
+		metadata->output[i].model_output_type =
+			cn10k_ml_io_type_map(metadata->output[i].model_output_type);
+	}
+}
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 912fdb9758..e25d6780e9 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -19,6 +19,309 @@ enum cn10k_ml_model_state {
 	ML_CN10K_MODEL_STATE_UNKNOWN,
 };
 
+/* Model Metadata : v 2.1.0.2 */
+#define MRVL_ML_MODEL_MAGIC_STRING "MRVL"
+#define MRVL_ML_MODEL_TARGET_ARCH  128
+#define MRVL_ML_MODEL_VERSION	   2100
+#define MRVL_ML_MODEL_NAME_LEN	   64
+#define MRVL_ML_INPUT_NAME_LEN	   16
+#define MRVL_ML_OUTPUT_NAME_LEN	   16
+#define MRVL_ML_INPUT_OUTPUT_SIZE  8
+
+/* Model file metadata structure */
+struct cn10k_ml_model_metadata {
+	/* Header (256-byte) */
+	struct {
+		/* Magic string ('M', 'R', 'V', 'L') */
+		uint8_t magic[4];
+
+		/* Metadata version */
+		uint8_t version[4];
+
+		/* Metadata size */
+		uint32_t metadata_size;
+
+		/* Unique ID */
+		uint8_t uuid[128];
+
+		/* Model target architecture
+		 * 0 = Undefined
+		 * 1 = M1K
+		 * 128 = MLIP
+		 * 256 = Experimental
+		 */
+		uint32_t target_architecture;
+		uint8_t reserved[104];
+
+		/* CRC of data after metadata_header (i.e. after first 256 bytes) */
+		uint32_t payload_crc32c;
+
+		/* CRC of first 252 bytes of metadata_header, after payload_crc calculation */
+		uint32_t header_crc32c;
+	} metadata_header;
+
+	/* Model information (256-byte) */
+	struct {
+		/* Model name string */
+		uint8_t name[MRVL_ML_MODEL_NAME_LEN];
+
+		/* Model version info (xx.xx.xx.xx) */
+		uint8_t version[4];
+
+		/* Model code size (Init + Main + Finish) */
+		uint32_t code_size;
+
+		/* Model data size (Weights and Bias) */
+		uint32_t data_size;
+
+		/* OCM start offset, set to ocm_wb_range_start */
+		uint32_t ocm_start;
+
+		/* OCM end offset, set to max OCM size */
+		uint32_t ocm_end;
+
+		/* Relocatable flag (always yes)
+		 * 0 = Not relocatable
+		 * 1 = Relocatable
+		 */
+		uint8_t ocm_relocatable;
+
+		/* Tile relocatable flag (always yes)
+		 * 0 = Not relocatable
+		 * 1 = Relocatable
+		 */
+		uint8_t tile_relocatable;
+
+		/* Start tile (Always 0) */
+		uint8_t tile_start;
+
+		/* End tile (num_tiles - 1) */
+		uint8_t tile_end;
+
+		/* Inference batch size */
+		uint8_t batch_size;
+
+		/* Number of input tensors (Max 8) */
+		uint8_t num_input;
+
+		/* Number of output tensors (Max 8) */
+		uint8_t num_output;
+		uint8_t reserved1;
+
+		/* Total input size in bytes */
+		uint32_t input_size;
+
+		/* Total output size in bytes */
+		uint32_t output_size;
+
+		/* Table size in bytes */
+		uint32_t table_size;
+
+		/* Number of layers in the network */
+		uint32_t num_layers;
+		uint32_t reserved2;
+
+		/* Floor of absolute OCM region */
+		uint64_t ocm_tmp_range_floor;
+
+		/* Relative OCM start address of WB data block */
+		uint64_t ocm_wb_range_start;
+
+		/* Relative OCM end address of WB data block */
+		uint64_t ocm_wb_range_end;
+
+		/* Relative DDR start address of WB data block */
+		uint64_t ddr_wb_range_start;
+
+		/* Relative DDR end address of WB data block */
+		uint64_t ddr_wb_range_end;
+
+		/* Relative DDR start address of all inputs */
+		uint64_t ddr_input_range_start;
+
+		/* Relative DDR end address of all inputs */
+		uint64_t ddr_input_range_end;
+
+		/* Relative DDR start address of all outputs */
+		uint64_t ddr_output_range_start;
+
+		/* Relative DDR end address of all outputs */
+		uint64_t ddr_output_range_end;
+
+		/* Compiler version */
+		uint8_t compiler_version[8];
+
+		/* CDK version */
+		uint8_t cdk_version[4];
+
+		/* Lower batch optimization support
+		 * 0 - No,
+		 * 1 - Yes
+		 */
+		uint8_t supports_lower_batch_size_optimization;
+		uint8_t reserved3[59];
+	} model;
+
+	/* Init section (64-byte) */
+	struct {
+		uint32_t file_offset;
+		uint32_t file_size;
+		uint8_t reserved[56];
+	} init_model;
+
+	/* Main section (64-byte) */
+	struct {
+		uint32_t file_offset;
+		uint32_t file_size;
+		uint8_t reserved[56];
+	} main_model;
+
+	/* Finish section (64-byte) */
+	struct {
+		uint32_t file_offset;
+		uint32_t file_size;
+		uint8_t reserved[56];
+	} finish_model;
+
+	uint8_t reserved1[512]; /* End of 2k bytes */
+
+	/* Weights and Bias (64-byte) */
+	struct {
+		/* Memory offset, set to ddr_wb_range_start */
+		uint64_t mem_offset;
+		uint32_t file_offset;
+		uint32_t file_size;
+
+		/* Relocatable flag for WB
+		 * 1 = Relocatable
+		 * 2 = Not relocatable
+		 */
+		uint8_t relocatable;
+		uint8_t reserved[47];
+	} weights_bias;
+
+	/* Input (512-byte, 64-byte per input) provisioned for 8 inputs */
+	struct {
+		/* DDR offset (in OCM absolute addresses for input) */
+		uint64_t mem_offset;
+
+		/* Relocatable flag
+		 * 1 = Relocatable
+		 * 2 = Not relocatable
+		 */
+		uint8_t relocatable;
+
+		/* Input quantization
+		 * 1 = Requires quantization
+		 * 2 = Pre-quantized
+		 */
+		uint8_t quantize;
+
+		/* Type of incoming input
+		 * 1 = INT8, 2 = UINT8, 3 = INT16, 4 = UINT16,
+		 * 5 = INT32, 6 = UINT32, 7 = FP16, 8 = FP32
+		 */
+		uint8_t input_type;
+
+		/* Type of input required by model
+		 * 1 = INT8, 2 = UINT8, 3 = INT16, 4 = UINT16,
+		 * 5 = INT32, 6 = UINT32, 7 = FP16, 8 = FP32
+		 */
+		uint8_t model_input_type;
+
+		/* float_32 qscale value
+		 * quantized = non-quantized * qscale
+		 */
+		float qscale;
+
+		/* Input shape */
+		struct {
+			/* Input format
+			 * 1 = NCHW
+			 * 2 = NHWC
+			 */
+			uint8_t format;
+			uint8_t reserved[3];
+			uint32_t w;
+			uint32_t x;
+			uint32_t y;
+			uint32_t z;
+		} shape;
+		uint8_t reserved[4];
+
+		/* Name of input */
+		uint8_t input_name[MRVL_ML_INPUT_NAME_LEN];
+
+		/* DDR range end
+		 * new = mem_offset + size_bytes - 1
+		 */
+		uint64_t ddr_range_end;
+	} input[MRVL_ML_INPUT_OUTPUT_SIZE];
+
+	/* Output (512-byte, 64-byte per output) provisioned for 8 outputs */
+	struct {
+		/* DDR offset in OCM absolute addresses for output */
+		uint64_t mem_offset;
+
+		/* Relocatable flag
+		 * 1 = Relocatable
+		 * 2 = Not relocatable
+		 */
+		uint8_t relocatable;
+
+		/* Output dequantization
+		 * 1 = De-quantization required
+		 * 2 = De-quantization not required
+		 */
+		uint8_t dequantize;
+
+		/* Type of outgoing output
+		 * 1 = INT8, 2 = UINT8, 3 = INT16, 4 = UINT16
+		 * 5 = INT32, 6 = UINT32, 7 = FP16, 8 = FP32
+		 */
+		uint8_t output_type;
+
+		/* Type of output produced by model
+		 * 1 = INT8, 2 = UINT8, 3 = INT16, 4 = UINT16
+		 * 5 = INT32, 6 = UINT32, 7 = FP16, 8 = FP32
+		 */
+		uint8_t model_output_type;
+
+		/* float_32 dscale value
+		 * dequantized = quantized * dscale
+		 */
+		float dscale;
+
+		/* Number of items in the output */
+		uint32_t size;
+		uint8_t reserved[20];
+
+		/* DDR range end
+		 * new = mem_offset + size_bytes - 1
+		 */
+		uint64_t ddr_range_end;
+		uint8_t output_name[MRVL_ML_OUTPUT_NAME_LEN];
+	} output[MRVL_ML_INPUT_OUTPUT_SIZE];
+
+	uint8_t reserved2[1792];
+
+	/* Model data */
+	struct {
+		uint8_t reserved1[4068];
+
+		/* Beta: xx.xx.xx.xx,
+		 * Later: YYYYMM.xx.xx
+		 */
+		uint8_t compiler_version[8];
+
+		/* M1K CDK version (xx.xx.xx.xx) */
+		uint8_t m1k_cdk_version[4];
+	} data;
+
+	/* Hidden 16 bytes of magic code */
+	uint8_t reserved3[16];
+};
+
 /* Model Object */
 struct cn10k_ml_model {
 	/* Device reference */
@@ -30,6 +333,12 @@ struct cn10k_ml_model {
 	/* ID */
 	int16_t model_id;
 
+	/* Batch size */
+	uint32_t batch_size;
+
+	/* Metadata */
+	struct cn10k_ml_model_metadata metadata;
+
 	/* Spinlock, used to update model state */
 	plt_spinlock_t lock;
 
@@ -37,4 +346,7 @@ struct cn10k_ml_model {
 	enum cn10k_ml_model_state state;
 };
 
+int cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size);
+void cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata);
+
 #endif /* _CN10K_ML_MODEL_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index d177d0e3e4..f7c1d43aee 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -416,8 +416,11 @@ cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
 	uint64_t mz_size;
 	uint16_t idx;
 	bool found;
+	int ret;
 
-	PLT_SET_USED(params);
+	ret = cn10k_ml_model_metadata_check(params->addr, params->size);
+	if (ret != 0)
+		return ret;
 
 	mldev = dev->data->dev_private;
 
@@ -450,6 +453,15 @@ cn10k_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
 	model->mldev = mldev;
 	model->model_id = idx;
 
+	rte_memcpy(&model->metadata, params->addr, sizeof(struct cn10k_ml_model_metadata));
+	cn10k_ml_model_metadata_update(&model->metadata);
+
+	/* Enable support for batch_size of 256 */
+	if (model->metadata.model.batch_size == 0)
+		model->batch_size = 256;
+	else
+		model->batch_size = model->metadata.model.batch_size;
+
 	plt_spinlock_init(&model->lock);
 	model->state = ML_CN10K_MODEL_STATE_LOADED;
 	dev->data->models[idx] = model;
diff --git a/drivers/ml/cnxk/meson.build b/drivers/ml/cnxk/meson.build
index bf7a9c0225..799e8f2470 100644
--- a/drivers/ml/cnxk/meson.build
+++ b/drivers/ml/cnxk/meson.build
@@ -19,7 +19,7 @@ sources = files(
         'cn10k_ml_model.c',
 )
 
-deps += ['mldev', 'common_cnxk', 'kvargs']
+deps += ['mldev', 'common_cnxk', 'kvargs', 'hash']
 
 if get_option('buildtype').contains('debug')
     cflags += [ '-DCNXK_ML_DEV_DEBUG' ]
-- 
2.17.1
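
Note on the header layout and version gate: the header CRC in cn10k_ml_model_metadata_check() covers sizeof(metadata_header) - sizeof(uint32_t) bytes, which only equals the "first 252 bytes" named in the struct comment if the header packs to exactly 256 bytes with header_crc32c as its last field. The sketch below is a standalone illustration, not part of the patch. The names demo_metadata_header, DEMO_MIN_VERSION and demo_version_supported are hypothetical stand-ins that copy the field order and the version arithmetic from the diff, and it assumes a typical ABI where uint32_t is 4-byte aligned.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for the patch's metadata_header (same field order). */
struct demo_metadata_header {
	uint8_t magic[4];
	uint8_t version[4];
	uint32_t metadata_size;
	uint8_t uuid[128];
	uint32_t target_architecture;
	uint8_t reserved[104];
	uint32_t payload_crc32c;
	uint32_t header_crc32c;
};

/* Compile-time checks: 256-byte header, header CRC stored at offset 252. */
static_assert(sizeof(struct demo_metadata_header) == 256,
	      "header is expected to be 256 bytes");
static_assert(offsetof(struct demo_metadata_header, header_crc32c) == 252,
	      "header CRC is expected to cover the first 252 bytes");

/* Mirrors MRVL_ML_MODEL_VERSION (2100) from the patch. */
#define DEMO_MIN_VERSION 2100

/* Same arithmetic as the patch: only the first two version bytes are compared. */
static bool
demo_version_supported(const uint8_t version[4])
{
	return version[0] * 1000 + version[1] * 100 >= DEMO_MIN_VERSION;
}
```

With this arithmetic, a metadata version of 2.1.0.2 passes and 2.0.9.9 is rejected, because the third and fourth version bytes never enter the comparison.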