@@ -385,4 +385,127 @@ rte_ml_dev_queue_pair_setup(int16_t dev_id, uint16_t queue_pair_id,
return (*dev->dev_ops->dev_queue_pair_setup)(dev, queue_pair_id, qp_conf, socket_id);
}
+int
+rte_ml_model_load(int16_t dev_id, struct rte_ml_model_params *params, uint16_t *model_id)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_load == NULL)
+ return -ENOTSUP;
+
+ if (params == NULL) {
+ RTE_MLDEV_LOG(ERR, "Dev %d, params cannot be NULL\n", dev_id);
+ return -EINVAL;
+ }
+
+ if (model_id == NULL) {
+ RTE_MLDEV_LOG(ERR, "Dev %d, model_id cannot be NULL\n", dev_id);
+ return -EINVAL;
+ }
+
+ return (*dev->dev_ops->model_load)(dev, params, model_id);
+}
+
+int
+rte_ml_model_unload(int16_t dev_id, uint16_t model_id)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_unload == NULL)
+ return -ENOTSUP;
+
+ return (*dev->dev_ops->model_unload)(dev, model_id);
+}
+
+int
+rte_ml_model_start(int16_t dev_id, uint16_t model_id)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_start == NULL)
+ return -ENOTSUP;
+
+ return (*dev->dev_ops->model_start)(dev, model_id);
+}
+
+int
+rte_ml_model_stop(int16_t dev_id, uint16_t model_id)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_stop == NULL)
+ return -ENOTSUP;
+
+ return (*dev->dev_ops->model_stop)(dev, model_id);
+}
+
+int
+rte_ml_model_info_get(int16_t dev_id, uint16_t model_id, struct rte_ml_model_info *model_info)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_info_get == NULL)
+ return -ENOTSUP;
+
+ if (model_info == NULL) {
+ RTE_MLDEV_LOG(ERR, "Dev %d, model_id %u, model_info cannot be NULL\n", dev_id,
+ model_id);
+ return -EINVAL;
+ }
+
+ return (*dev->dev_ops->model_info_get)(dev, model_id, model_info);
+}
+
+int
+rte_ml_model_params_update(int16_t dev_id, uint16_t model_id, void *buffer)
+{
+ struct rte_ml_dev *dev;
+
+ if (!rte_ml_dev_is_valid_dev(dev_id)) {
+ RTE_MLDEV_LOG(ERR, "Invalid dev_id = %d\n", dev_id);
+ return -EINVAL;
+ }
+
+ dev = rte_ml_dev_pmd_get_dev(dev_id);
+ if (*dev->dev_ops->model_params_update == NULL)
+ return -ENOTSUP;
+
+ if (buffer == NULL) {
+ RTE_MLDEV_LOG(ERR, "Dev %d, buffer cannot be NULL\n", dev_id);
+ return -EINVAL;
+ }
+
+ return (*dev->dev_ops->model_params_update)(dev, model_id, buffer);
+}
+
RTE_LOG_REGISTER_DEFAULT(rte_ml_dev_logtype, INFO);
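
For reference, a minimal usage sketch (not part of this patch) of the model lifecycle added above, assuming an already configured and started mldev device; the helper name and the addr/size buffer members of struct rte_ml_model_params are illustrative assumptions, and the model binary is expected to already be in memory:

#include <rte_mldev.h>

static int
app_model_lifecycle(int16_t dev_id, void *model_buf, size_t model_len)
{
	/* addr/size describe the in-memory model binary */
	struct rte_ml_model_params params = {
		.addr = model_buf,
		.size = model_len,
	};
	struct rte_ml_model_info info;
	uint16_t model_id;
	int ret;

	ret = rte_ml_model_load(dev_id, &params, &model_id);
	if (ret != 0)
		return ret;

	ret = rte_ml_model_start(dev_id, model_id);
	if (ret != 0)
		goto unload;

	ret = rte_ml_model_info_get(dev_id, model_id, &info);
	if (ret != 0)
		goto stop;

	/* enqueue/dequeue inference requests using model_id here */

stop:
	rte_ml_model_stop(dev_id, model_id);
unload:
	rte_ml_model_unload(dev_id, model_id);
	return ret;
}

Each wrapper returns 0 on success, -EINVAL on invalid arguments, or -ENOTSUP when the driver does not implement the op, so the teardown path above tolerates partial driver support.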
@@ -151,6 +151,110 @@ typedef int (*mldev_queue_pair_setup_t)(struct rte_ml_dev *dev, uint16_t queue_p
*/
typedef int (*mldev_queue_pair_release_t)(struct rte_ml_dev *dev, uint16_t queue_pair_id);
+/**
+ * @internal
+ *
+ * Function used to load an ML model.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param params
+ * Model load params.
+ * @param model_id
+ * Model ID returned by the library.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_load_t)(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
+ uint16_t *model_id);
+
+/**
+ * @internal
+ *
+ * Function used to unload an ML model.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param model_id
+ * Model ID to use.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_unload_t)(struct rte_ml_dev *dev, uint16_t model_id);
+
+/**
+ * @internal
+ *
+ * Function used to start an ML model.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param model_id
+ * Model ID to use.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_start_t)(struct rte_ml_dev *dev, uint16_t model_id);
+
+/**
+ * @internal
+ *
+ * Function used to stop an ML model.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param model_id
+ * Model ID to use.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_stop_t)(struct rte_ml_dev *dev, uint16_t model_id);
+
+/**
+ * @internal
+ *
+ * Get info about a model.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param model_id
+ * Model ID to use.
+ * @param model_info
+ * Pointer to model info structure.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_info_get_t)(struct rte_ml_dev *dev, uint16_t model_id,
+ struct rte_ml_model_info *model_info);
+
+/**
+ * @internal
+ *
+ * Update model params.
+ *
+ * @param dev
+ * ML device pointer.
+ * @param model_id
+ * Model ID to use.
+ * @param buffer
+ * Pointer to model params.
+ *
+ * @return
+ * - 0 on success.
+ * - < 0, error on failure.
+ */
+typedef int (*mldev_model_params_update_t)(struct rte_ml_dev *dev, uint16_t model_id, void *buffer);
+
/**
* @internal
*
@@ -177,6 +281,24 @@ struct rte_ml_dev_ops {
/** Release a device queue pair. */
mldev_queue_pair_release_t dev_queue_pair_release;
+
+ /** Load an ML model. */
+ mldev_model_load_t model_load;
+
+ /** Unload an ML model. */
+ mldev_model_unload_t model_unload;
+
+ /** Start an ML model. */
+ mldev_model_start_t model_start;
+
+ /** Stop an ML model. */
+ mldev_model_stop_t model_stop;
+
+ /** Get model information. */
+ mldev_model_info_get_t model_info_get;
+
+ /** Update model params. */
+ mldev_model_params_update_t model_params_update;
};
/**
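
As a sketch of how a driver plugs into these new ops (all names below are hypothetical, not part of this patch): the PMD implements the prototypes declared above and fills the corresponding rte_ml_dev_ops members; any member left NULL makes the matching rte_ml_model_*() call return -ENOTSUP, per the checks in rte_mldev.c.

#include <rte_common.h>
#include <rte_mldev_core.h>	/* driver-side header declaring struct rte_ml_dev_ops; exact include path is an assumption */

static int
sketch_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params,
		  uint16_t *model_id)
{
	RTE_SET_USED(dev);
	RTE_SET_USED(params);
	*model_id = 0;	/* a real driver allocates a free model slot here */
	return 0;
}

static int
sketch_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
{
	RTE_SET_USED(dev);
	RTE_SET_USED(model_id);	/* a real driver releases the model slot here */
	return 0;
}

struct rte_ml_dev_ops sketch_dev_ops = {
	.model_load = sketch_model_load,
	.model_unload = sketch_model_unload,
	/* model_start/stop/info_get/params_update left NULL: -ENOTSUP */
};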
@@ -12,6 +12,12 @@ EXPERIMENTAL {
rte_ml_dev_socket_id;
rte_ml_dev_start;
rte_ml_dev_stop;
+ rte_ml_model_info_get;
+ rte_ml_model_load;
+ rte_ml_model_params_update;
+ rte_ml_model_start;
+ rte_ml_model_stop;
+ rte_ml_model_unload;
local: *;
};