From patchwork Wed Sep 20 07:19:08 2023
X-Patchwork-Submitter: Srikanth Yalavarthi
X-Patchwork-Id: 131672
X-Patchwork-Delegate: thomas@monjalon.net
From: Srikanth Yalavarthi
To: Srikanth Yalavarthi
Subject: [PATCH v2 1/3] mldev: add support for arbitrary shape dimensions
Date: Wed, 20 Sep 2023 00:19:08 -0700
Message-ID: <20230920071910.10428-2-syalavarthi@marvell.com>
In-Reply-To: <20230920071910.10428-1-syalavarthi@marvell.com>
References: <20230830155303.30380-1-syalavarthi@marvell.com>
 <20230920071910.10428-1-syalavarthi@marvell.com>
List-Id: DPDK patches and discussions

Updated rte_ml_io_info to support shapes with an arbitrary number of
dimensions. Dropped use of rte_ml_io_shape and rte_ml_io_format.
Introduced new fields nb_elements and size in rte_ml_io_info. Updated
drivers and app/mldev to support the changes.
Signed-off-by: Srikanth Yalavarthi
---
 app/test-mldev/test_inference_common.c | 97 +++++--------------------
 drivers/ml/cnxk/cn10k_ml_model.c       | 78 +++++++++++++--------
 drivers/ml/cnxk/cn10k_ml_model.h       | 12 ++++
 drivers/ml/cnxk/cn10k_ml_ops.c         | 11 +--
 lib/mldev/mldev_utils.c                | 30 --------
 lib/mldev/mldev_utils.h                | 16 -----
 lib/mldev/rte_mldev.h                  | 59 ++++------------
 lib/mldev/version.map                  |  1 -
 8 files changed, 94 insertions(+), 210 deletions(-)

diff --git a/app/test-mldev/test_inference_common.c b/app/test-mldev/test_inference_common.c
index 05b221401b..b40519b5e3 100644
--- a/app/test-mldev/test_inference_common.c
+++ b/app/test-mldev/test_inference_common.c
@@ -3,6 +3,7 @@
  */

 #include
+#include
 #include
 #include
@@ -18,11 +19,6 @@
 #include "ml_common.h"
 #include "test_inference_common.h"

-#define ML_TEST_READ_TYPE(buffer, type) (*((type *)buffer))
-
-#define ML_TEST_CHECK_OUTPUT(output, reference, tolerance) \
-	(((float)output - (float)reference) <= (((float)reference * tolerance) / 100.0))
-
 #define ML_OPEN_WRITE_GET_ERR(name, buffer, size, err) \
 	do { \
 		FILE *fp = fopen(name, "w+"); \
@@ -763,9 +759,9 @@ ml_inference_validation(struct ml_test *test, struct ml_request *req)
 {
 	struct test_inference *t = ml_test_priv((struct ml_test *)test);
 	struct ml_model *model;
-	uint32_t nb_elements;
-	uint8_t *reference;
-	uint8_t *output;
+	float *reference;
+	float *output;
+	float deviation;
 	bool match;
 	uint32_t i;
 	uint32_t j;
@@ -777,89 +773,30 @@ ml_inference_validation(struct ml_test *test, struct ml_request *req)
 		match = (rte_hash_crc(model->output, model->out_dsize, 0) ==
 			 rte_hash_crc(model->reference, model->out_dsize, 0));
 	} else {
-		output = model->output;
-		reference = model->reference;
+		output = (float *)model->output;
+		reference = (float *)model->reference;

 		i = 0;
 next_output:
-		nb_elements =
-			model->info.output_info[i].shape.w * model->info.output_info[i].shape.x *
-			model->info.output_info[i].shape.y * model->info.output_info[i].shape.z;
 		j = 0;
 next_element:
 		match = false;
-		switch (model->info.output_info[i].dtype) {
-		case RTE_ML_IO_TYPE_INT8:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int8_t),
-						 ML_TEST_READ_TYPE(reference, int8_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(int8_t);
-			reference += sizeof(int8_t);
-			break;
-		case RTE_ML_IO_TYPE_UINT8:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint8_t),
-						 ML_TEST_READ_TYPE(reference, uint8_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(float);
-			reference += sizeof(float);
-			break;
-		case RTE_ML_IO_TYPE_INT16:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int16_t),
-						 ML_TEST_READ_TYPE(reference, int16_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(int16_t);
-			reference += sizeof(int16_t);
-			break;
-		case RTE_ML_IO_TYPE_UINT16:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint16_t),
-						 ML_TEST_READ_TYPE(reference, uint16_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(uint16_t);
-			reference += sizeof(uint16_t);
-			break;
-		case RTE_ML_IO_TYPE_INT32:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int32_t),
-						 ML_TEST_READ_TYPE(reference, int32_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(int32_t);
-			reference += sizeof(int32_t);
-			break;
-		case RTE_ML_IO_TYPE_UINT32:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint32_t),
-						 ML_TEST_READ_TYPE(reference, uint32_t),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(uint32_t);
-			reference += sizeof(uint32_t);
-			break;
-		case RTE_ML_IO_TYPE_FP32:
-			if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, float),
-						 ML_TEST_READ_TYPE(reference, float),
-						 t->cmn.opt->tolerance))
-				match = true;
-
-			output += sizeof(float);
-			reference += sizeof(float);
-			break;
-		default: /* other types, fp8, fp16, bfloat16 */
+		deviation =
+			(*reference == 0 ? 0 : 100 * fabs(*output - *reference) / fabs(*reference));
+		if (deviation <= t->cmn.opt->tolerance)
 			match = true;
-		}
+		else
+			ml_err("id = %d, element = %d, output = %f, reference = %f, deviation = %f %%\n",
+			       i, j, *output, *reference, deviation);
+
+		output++;
+		reference++;

 		if (!match)
 			goto done;
+
 		j++;
-		if (j < nb_elements)
+		if (j < model->info.output_info[i].nb_elements)
 			goto next_element;

 		i++;
diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 92c47d39ba..26df8d9ff9 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -366,6 +366,12 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 	addr->total_input_sz_q = 0;
 	for (i = 0; i < metadata->model.num_input; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			addr->input[i].nb_dims = 4;
+			addr->input[i].shape[0] = metadata->input1[i].shape.w;
+			addr->input[i].shape[1] = metadata->input1[i].shape.x;
+			addr->input[i].shape[2] = metadata->input1[i].shape.y;
+			addr->input[i].shape[3] = metadata->input1[i].shape.z;
+
 			addr->input[i].nb_elements =
 				metadata->input1[i].shape.w * metadata->input1[i].shape.x *
 				metadata->input1[i].shape.y * metadata->input1[i].shape.z;
@@ -386,6 +392,13 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 				addr->input[i].sz_q);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
+			addr->input[i].nb_dims = 4;
+			addr->input[i].shape[0] = metadata->input2[j].shape.w;
+			addr->input[i].shape[1] = metadata->input2[j].shape.x;
+			addr->input[i].shape[2] = metadata->input2[j].shape.y;
+			addr->input[i].shape[3] = metadata->input2[j].shape.z;
+
 			addr->input[i].nb_elements =
 				metadata->input2[j].shape.w * metadata->input2[j].shape.x *
 				metadata->input2[j].shape.y * metadata->input2[j].shape.z;
@@ -412,6 +425,8 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 	addr->total_output_sz_d = 0;
 	for (i = 0; i < metadata->model.num_output; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+			addr->output[i].nb_dims = 1;
+			addr->output[i].shape[0] = metadata->output1[i].size;
 			addr->output[i].nb_elements = metadata->output1[i].size;
 			addr->output[i].sz_d =
 				addr->output[i].nb_elements *
@@ -426,6 +441,9 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
 				model->model_id, i, addr->output[i].sz_d, addr->output[i].sz_q);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
+			addr->output[i].nb_dims = 1;
+			addr->output[i].shape[0] = metadata->output2[j].size;
 			addr->output[i].nb_elements = metadata->output2[j].size;
 			addr->output[i].sz_d =
 				addr->output[i].nb_elements *
@@ -498,6 +516,7 @@ void
 cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 {
 	struct cn10k_ml_model_metadata *metadata;
+	struct cn10k_ml_model_addr *addr;
 	struct rte_ml_model_info *info;
 	struct rte_ml_io_info *output;
 	struct rte_ml_io_info *input;
@@ -508,6 +527,7 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 	info = PLT_PTR_CAST(model->info);
 	input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
 	output = PLT_PTR_ADD(input, metadata->model.num_input * sizeof(struct rte_ml_io_info));
+	addr = &model->addr;

 	/* Set model info */
 	memset(info, 0, sizeof(struct rte_ml_model_info));
@@ -529,24 +549,25 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
 			rte_memcpy(input[i].name, metadata->input1[i].input_name,
 				   MRVL_ML_INPUT_NAME_LEN);
-			input[i].dtype = metadata->input1[i].input_type;
-			input[i].qtype = metadata->input1[i].model_input_type;
-			input[i].shape.format = metadata->input1[i].shape.format;
-			input[i].shape.w = metadata->input1[i].shape.w;
-			input[i].shape.x = metadata->input1[i].shape.x;
-			input[i].shape.y = metadata->input1[i].shape.y;
-			input[i].shape.z = metadata->input1[i].shape.z;
+			input[i].nb_dims = addr->input[i].nb_dims;
+			input[i].shape = addr->input[i].shape;
+			input[i].type = metadata->input1[i].model_input_type;
+			input[i].nb_elements = addr->input[i].nb_elements;
+			input[i].size =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
 			rte_memcpy(input[i].name, metadata->input2[j].input_name,
 				   MRVL_ML_INPUT_NAME_LEN);
-			input[i].dtype = metadata->input2[j].input_type;
-			input[i].qtype = metadata->input2[j].model_input_type;
-			input[i].shape.format = metadata->input2[j].shape.format;
-			input[i].shape.w = metadata->input2[j].shape.w;
-			input[i].shape.x = metadata->input2[j].shape.x;
-			input[i].shape.y = metadata->input2[j].shape.y;
-			input[i].shape.z = metadata->input2[j].shape.z;
+			input[i].nb_dims = addr->input[i].nb_dims;
+			input[i].shape = addr->input[i].shape;
+			input[i].type = metadata->input2[j].model_input_type;
+			input[i].nb_elements = addr->input[i].nb_elements;
+			input[i].size =
+				addr->input[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
 		}
 	}
@@ -555,24 +576,25 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
 			rte_memcpy(output[i].name, metadata->output1[i].output_name,
 				   MRVL_ML_OUTPUT_NAME_LEN);
-			output[i].dtype = metadata->output1[i].output_type;
-			output[i].qtype = metadata->output1[i].model_output_type;
-			output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-			output[i].shape.w = metadata->output1[i].size;
-			output[i].shape.x = 1;
-			output[i].shape.y = 1;
-			output[i].shape.z = 1;
+			output[i].nb_dims = addr->output[i].nb_dims;
+			output[i].shape = addr->output[i].shape;
+			output[i].type = metadata->output1[i].model_output_type;
+			output[i].nb_elements = addr->output[i].nb_elements;
+			output[i].size =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
 			rte_memcpy(output[i].name, metadata->output2[j].output_name,
 				   MRVL_ML_OUTPUT_NAME_LEN);
-			output[i].dtype = metadata->output2[j].output_type;
-			output[i].qtype = metadata->output2[j].model_output_type;
-			output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-			output[i].shape.w = metadata->output2[j].size;
-			output[i].shape.x = 1;
-			output[i].shape.y = 1;
-			output[i].shape.z = 1;
+			output[i].nb_dims = addr->output[i].nb_dims;
+			output[i].shape = addr->output[i].shape;
+			output[i].type = metadata->output2[j].model_output_type;
+			output[i].nb_elements = addr->output[i].nb_elements;
+			output[i].size =
+				addr->output[i].nb_elements *
+				rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
 		}
 	}
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 1f689363fc..4cc0744891 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -409,6 +409,12 @@ struct cn10k_ml_model_addr {

 	/* Input address and size */
 	struct {
+		/* Number of dimensions in shape */
+		uint32_t nb_dims;
+
+		/* Shape of input */
+		uint32_t shape[4];
+
 		/* Number of elements */
 		uint32_t nb_elements;
@@ -421,6 +427,12 @@ struct cn10k_ml_model_addr {

 	/* Output address and size */
 	struct {
+		/* Number of dimensions in shape */
+		uint32_t nb_dims;
+
+		/* Shape of output */
+		uint32_t shape[4];
+
 		/* Number of elements */
 		uint32_t nb_elements;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 656467d891..e3faab81ba 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -321,8 +321,8 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)

 	fprintf(fp, "\n");
 	print_line(fp, LINE_LEN);
-	fprintf(fp, "%8s %16s %12s %18s %12s %14s\n", "input", "input_name", "input_type",
-		"model_input_type", "quantize", "format");
+	fprintf(fp, "%8s %16s %12s %18s %12s\n", "input", "input_name", "input_type",
+		"model_input_type", "quantize");
 	print_line(fp, LINE_LEN);
 	for (i = 0; i < model->metadata.model.num_input; i++) {
 		if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
@@ -335,12 +335,10 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 			fprintf(fp, "%*s ", 18, str);
 			fprintf(fp, "%*s", 12,
 				(model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
-			rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str,
-						STR_LEN);
-			fprintf(fp, "%*s", 16, str);
 			fprintf(fp, "\n");
 		} else {
 			j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
 			fprintf(fp, "%8u ", i);
 			fprintf(fp, "%*s ", 16, model->metadata.input2[j].input_name);
 			rte_ml_io_type_to_str(model->metadata.input2[j].input_type, str, STR_LEN);
@@ -350,9 +348,6 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 			fprintf(fp, "%*s ", 18, str);
 			fprintf(fp, "%*s", 12,
 				(model->metadata.input2[j].quantize == 1 ? "Yes" : "No"));
-			rte_ml_io_format_to_str(model->metadata.input2[j].shape.format, str,
-						STR_LEN);
-			fprintf(fp, "%*s", 16, str);
 			fprintf(fp, "\n");
 		}
 	}
diff --git a/lib/mldev/mldev_utils.c b/lib/mldev/mldev_utils.c
index d2442b123b..ccd2c39ca8 100644
--- a/lib/mldev/mldev_utils.c
+++ b/lib/mldev/mldev_utils.c
@@ -86,33 +86,3 @@ rte_ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len)
 		rte_strlcpy(str, "invalid", len);
 	}
 }
-
-void
-rte_ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len)
-{
-	switch (format) {
-	case RTE_ML_IO_FORMAT_NCHW:
-		rte_strlcpy(str, "NCHW", len);
-		break;
-	case RTE_ML_IO_FORMAT_NHWC:
-		rte_strlcpy(str, "NHWC", len);
-		break;
-	case RTE_ML_IO_FORMAT_CHWN:
-		rte_strlcpy(str, "CHWN", len);
-		break;
-	case RTE_ML_IO_FORMAT_3D:
-		rte_strlcpy(str, "3D", len);
-		break;
-	case RTE_ML_IO_FORMAT_2D:
-		rte_strlcpy(str, "Matrix", len);
-		break;
-	case RTE_ML_IO_FORMAT_1D:
-		rte_strlcpy(str, "Vector", len);
-		break;
-	case RTE_ML_IO_FORMAT_SCALAR:
-		rte_strlcpy(str, "Scalar", len);
-		break;
-	default:
-		rte_strlcpy(str, "invalid", len);
-	}
-}
diff --git a/lib/mldev/mldev_utils.h b/lib/mldev/mldev_utils.h
index 5bc8020453..220afb42f0 100644
--- a/lib/mldev/mldev_utils.h
+++ b/lib/mldev/mldev_utils.h
@@ -52,22 +52,6 @@
 __rte_internal
 void
 rte_ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len);

-/**
- * @internal
- *
- * Get the name of an ML IO format.
- *
- * @param[in] type
- *	Enumeration of ML IO format.
- * @param[in] str
- *	Address of character array.
- * @param[in] len
- *	Length of character array.
- */
-__rte_internal
-void
-rte_ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len);
-
 /**
  * @internal
  *
diff --git a/lib/mldev/rte_mldev.h b/lib/mldev/rte_mldev.h
index fc3525c1ab..6204df0930 100644
--- a/lib/mldev/rte_mldev.h
+++ b/lib/mldev/rte_mldev.h
@@ -863,47 +863,6 @@ enum rte_ml_io_type {
 	/**< 16-bit brain floating point number. */
 };

-/**
- * Input and output format. This is used to represent the encoding type of multi-dimensional
- * used by ML models.
- */
-enum rte_ml_io_format {
-	RTE_ML_IO_FORMAT_NCHW = 1,
-	/**< Batch size (N) x channels (C) x height (H) x width (W) */
-	RTE_ML_IO_FORMAT_NHWC,
-	/**< Batch size (N) x height (H) x width (W) x channels (C) */
-	RTE_ML_IO_FORMAT_CHWN,
-	/**< Channels (C) x height (H) x width (W) x batch size (N) */
-	RTE_ML_IO_FORMAT_3D,
-	/**< Format to represent a 3 dimensional data */
-	RTE_ML_IO_FORMAT_2D,
-	/**< Format to represent matrix data */
-	RTE_ML_IO_FORMAT_1D,
-	/**< Format to represent vector data */
-	RTE_ML_IO_FORMAT_SCALAR,
-	/**< Format to represent scalar data */
-};
-
-/**
- * Input and output shape. This structure represents the encoding format and dimensions
- * of the tensor or vector.
- *
- * The data can be a 4D / 3D tensor, matrix, vector or a scalar. Number of dimensions used
- * for the data would depend on the format. Unused dimensions to be set to 1.
- */
-struct rte_ml_io_shape {
-	enum rte_ml_io_format format;
-	/**< Format of the data */
-	uint32_t w;
-	/**< First dimension */
-	uint32_t x;
-	/**< Second dimension */
-	uint32_t y;
-	/**< Third dimension */
-	uint32_t z;
-	/**< Fourth dimension */
-};
-
 /** Input and output data information structure
 *
 * Specifies the type and shape of input and output data.
@@ -911,12 +870,18 @@ struct rte_ml_io_shape {
 struct rte_ml_io_info {
 	char name[RTE_ML_STR_MAX];
 	/**< Name of data */
-	struct rte_ml_io_shape shape;
-	/**< Shape of data */
-	enum rte_ml_io_type qtype;
-	/**< Type of quantized data */
-	enum rte_ml_io_type dtype;
-	/**< Type of de-quantized data */
+	uint32_t nb_dims;
+	/**< Number of dimensions in shape */
+	uint32_t *shape;
+	/**< Shape of the tensor */
+	enum rte_ml_io_type type;
+	/**< Type of data
+	 * @see enum rte_ml_io_type
+	 */
+	uint64_t nb_elements;
+	/**< Number of elements in tensor */
+	uint64_t size;
+	/**< Size of tensor in bytes */
 };

 /** Model information structure */
diff --git a/lib/mldev/version.map b/lib/mldev/version.map
index 0706b565be..40ff27f4b9 100644
--- a/lib/mldev/version.map
+++ b/lib/mldev/version.map
@@ -51,7 +51,6 @@ INTERNAL {

 	rte_ml_io_type_size_get;
 	rte_ml_io_type_to_str;
-	rte_ml_io_format_to_str;
 	rte_ml_io_float32_to_int8;
 	rte_ml_io_int8_to_float32;
 	rte_ml_io_float32_to_uint8;