From: Anup Prabhu <aprabhu@marvell.com>
Enabled check for OCM size requirement for multi-layer
TVM model. Compute OCM scratch and WB requirement for
all layers during the load stage.
Signed-off-by: Anup Prabhu <aprabhu@marvell.com>
---
drivers/ml/cnxk/cnxk_ml_ops.c | 60 +++++++++++++++++++++++++++++++++++
1 file changed, 60 insertions(+)
@@ -1023,8 +1023,12 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
char str[RTE_MEMZONE_NAMESIZE];
const struct plt_memzone *mz;
+ uint16_t max_scratch_pages;
+ struct cn10k_ml_ocm *ocm;
uint64_t model_info_size;
+ uint16_t total_wb_pages;
uint16_t lcl_model_id;
+ uint16_t layer_id;
uint64_t mz_size;
bool found;
int ret;
@@ -1086,6 +1090,62 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
if (ret != 0)
goto error;
+ max_scratch_pages = 0;
+ total_wb_pages = 0;
+ layer_id = 0;
+
+ ocm = &cnxk_mldev->cn10k_mldev.ocm;
+
+ if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+ total_wb_pages = total_wb_pages + model->layer[layer_id].glow.ocm_map.wb_pages;
+ max_scratch_pages = PLT_MAX(max_scratch_pages,
+ model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ } else {
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+ if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) {
+ total_wb_pages = total_wb_pages +
+ model->layer[layer_id].glow.ocm_map.wb_pages;
+ max_scratch_pages =
+ PLT_MAX(max_scratch_pages,
+ model->layer[layer_id].glow.ocm_map.scratch_pages);
+ }
+ }
+#endif
+ }
+
+ if ((total_wb_pages + max_scratch_pages) > ocm->num_pages) {
+ plt_err("model_id = %u: total_wb_pages (%u) + scratch_pages (%u) > %u\n",
+ lcl_model_id, total_wb_pages, max_scratch_pages, ocm->num_pages);
+
+ if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+ plt_ml_dbg("layer_id = %u: wb_pages = %u, scratch_pages = %u\n", layer_id,
+ model->layer[layer_id].glow.ocm_map.wb_pages,
+ model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ } else {
+ for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers;
+ layer_id++) {
+ if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) {
+ plt_ml_dbg(
+ "layer_id = %u: wb_pages = %u, scratch_pages = %u\n",
+ layer_id,
+ model->layer[layer_id].glow.ocm_map.wb_pages,
+ model->layer[layer_id].glow.ocm_map.scratch_pages);
+ }
+ }
+#endif
+ }
+
+ if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+ cn10k_ml_model_unload(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+ else {
+ mvtvm_ml_model_unload(cnxk_mldev, model);
+ return -ENOMEM;
+ }
+#endif
+ }
plt_spinlock_init(&model->lock);
model->state = ML_CNXK_MODEL_STATE_LOADED;
cnxk_mldev->nb_models_loaded++;