Help us learn about your current experience with the documentation. Take the survey.

MLflow 客户端兼容性

  • 等级:Free, Premium, Ultimate
  • 提供:GitLab.com, GitLab Self-Managed, GitLab Dedicated

MLflow 是一个流行的开源机器学习实验跟踪工具。 GitLab 模型实验跟踪 和 GitLab 模型注册表 与 MLflow 客户端兼容。设置只需对现有代码进行少量修改。

启用 MLflow 客户端集成

先决条件:

  • 一个兼容 GitLab 的 Python 客户端:
  • 一个具有至少 Developer 角色和 api 范围的 个人项目 访问令牌
  • 项目 ID。要查找项目 ID:
    1. 在左侧边栏,选择 Search or go to 并找到您的项目。
    2. 选择 Settings > General

要从本地环境使用 MLflow 客户端兼容性:

  1. 在运行代码的主机上设置跟踪 URI 和令牌环境变量。 这可以是您的本地环境、CI 管道或远程主机。例如:

    export MLFLOW_TRACKING_URI="<your gitlab endpoint>/api/v4/projects/<your project id>/ml/mlflow"
    export MLFLOW_TRACKING_TOKEN="<your_access_token>"
  2. 如果训练代码包含对 mlflow.set_tracking_uri() 的调用,请将其移除。

在模型注册表中,您可以通过选择右上角的溢出菜单中的垂直省略号( ellipsis_v )来复制跟踪 URI。

模型实验

运行训练代码时,可以使用 MLflow 客户端在 GitLab 上创建实验、运行、 模型、模型版本,记录参数、指标、元数据和工件。

实验记录后,它们将列在 /<your project>/-/ml/experiments 下。

运行会被注册,您可以通过选择实验、模型或模型版本来探索它们。

创建实验

import mlflow

# 创建新实验
experiment_id = mlflow.create_experiment(name="<your_experiment>")

# 设置活动实验,如果不存在也会创建新实验
mlflow.set_experiment(experiment_name="<your_experiment>")

创建运行

import mlflow

# 创建运行需要实验 ID 或活动实验
mlflow.set_experiment(experiment_name="<your_experiment>")

# 可以使用或不用上下文管理器创建运行
with mlflow.start_run() as run:
    print(run.info.run_id)
    # 您的训练代码

with mlflow.start_run():
    # 您的训练代码

记录参数和指标

import mlflow

mlflow.set_experiment(experiment_name="<your_experiment>")

with mlflow.start_run():
    # 参数键在运行范围内必须是唯一的
    mlflow.log_param(key="param_1", value=1)

    # 指标可以在运行过程中更新
    mlflow.log_metric(key="metrics_1", value=1)
    mlflow.log_metric(key="metrics_1", value=2)

记录工件

import mlflow

mlflow.set_experiment(experiment_name="<your_experiment>")

with mlflow.start_run():
    # 纯文本文件可以使用 `log_text` 作为工件记录
    mlflow.log_text('Hello, World!', artifact_file='hello.txt')

    mlflow.log_artifact(
        local_path='<local/path/to/file.txt>',
        artifact_path='<可选的相对路径,用于记录工件>'
    )

记录模型

可以使用支持的 MLflow 模型类型 之一记录模型。 使用模型类型记录元数据,可以更轻松地在不同工具和环境之间管理、加载和部署模型。

import mlflow
from sklearn.ensemble import RandomForestClassifier

mlflow.set_experiment(experiment_name="<your_experiment>")

with mlflow.start_run():
    # 创建并训练一个简单模型
    model = RandomForestClassifier(n_estimators=10, random_state=42)
    model.fit(X_train, y_train)

    # 使用 MLflow sklearn 模型类型记录模型
    mlflow.sklearn.log_model(model, artifact_path="")

加载运行

您可以从 GitLab 模型注册表加载运行,例如用于进行预测。

import mlflow
import mlflow.pyfunc

run_id = "<your_run_id>"
download_path = "models"  # 本地下载文件夹

mlflow.pyfunc.load_model(f"runs:/{run_id}/", dst_path=download_path)

sample_input = [[1,0,3,4],[2,0,1,2]]
model.predict(data=sample_input)

将运行关联到 CI/CD 作业

如果您的训练代码正在从 CI/CD 作业运行,GitLab 可以使用该信息来增强运行元数据。要将运行关联到 CI/CD 作业:

  1. 项目 CI 变量 中包含以下变量:

    • MLFLOW_TRACKING_URI: "<your gitlab endpoint>/api/v4/projects/<your project id>/ml/mlflow"
    • MLFLOW_TRACKING_TOKEN: <your_access_token>
  2. 在运行执行上下文中的训练代码中,添加以下代码片段:

    import os
    import mlflow
    
    with mlflow.start_run(run_name=f"Run {index}"):
      # 您的训练代码
    
      # 要包含的代码片段开始
      if os.getenv('GITLAB_CI'):
        mlflow.set_tag('gitlab.CI_JOB_ID', os.getenv('CI_JOB_ID'))
      # 要包含的代码片段结束

模型注册表

您还可以使用 MLflow 客户端管理模型和模型版本。模型注册在 /<your project>/-/ml/models 下。

模型

创建模型

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
description = '模型描述'
model = client.create_registered_model(model_name, description=description)

注意事项

  • create_registered_model 参数 tags 会被忽略。
  • name 在项目内必须唯一。
  • name 不能是现有实验的名称。

获取模型

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
model = client.get_registered_model(model_name)

更新模型

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
description = '新描述'
client.update_registered_model(model_name, description=description)

删除模型

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
client.delete_registered_model(model_name)

将运行记录到模型

每个模型都有一个关联的实验,名称前缀为 [model]。 要将运行记录到模型,请使用传递正确名称的实验:

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
exp = client.get_experiment_by_name(f"[model]{model_name}")
run = client.create_run(exp.experiment_id)

模型版本

创建模型版本

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
description = '模型版本描述'
model_version = client.create_model_version(model_name, source="", description=description)

如果未传递版本参数,它将从最新上传的版本自动递增。您可以在创建模型版本时通过传递标签来设置版本。版本必须遵循 SemVer 格式。

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
tags = { "gitlab.version": version }
client.create_model_version(model_name, version, description=description, tags=tags)

注意事项

  • 参数 run_id 会被忽略。每个模型版本都表现为一个运行。从运行创建模型版本尚不支持。
  • 参数 source 会被忽略。GitLab 将为模型版本文件创建包位置。
  • 参数 run_link 会被忽略。
  • 参数 await_creation_for 会被忽略。

更新模型版本

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
description = '新描述'
client.update_model_version(model_name, version, description=description)

获取模型版本

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
client.get_model_version(model_name, version)

获取模型的最新版本

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
client.get_latest_versions(model_name)

注意事项

  • 参数 stages 会被忽略。
  • 版本按最高语义版本排序。

加载模型版本

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'  # 例如:'1.0.0'

# 或者搜索版本
version = mlflow.search_registered_models(filter_string="name='{model_name}'")[0].latest_versions[0].version

model = mlflow.pyfunc.load_model(f"models:/{model_name}/{latest_version}")

# 或者加载最新版本
model = mlflow.pyfunc.load_model(f"models:/{model_name}/latest")

向模型版本记录指标和参数

每个模型版本也是一个运行,允许用户记录参数和指标。运行 ID 可以在 GitLab 的模型版本页面找到,或使用 MLflow 客户端:

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
model_version = client.get_model_version(model_name, version)
run_id = model_version.run_id

# 您的训练代码

client.log_metric(run_id, '<metric_name>', '<metric_value>')
client.log_param(run_id, '<param_name>', '<param_value>')
client.log_batch(run_id, metric_list, param_list, tag_list)

由于每个文件有 5 GB 的大小限制,您必须对较大的模型进行分区。

向模型版本记录工件

GitLab 创建一个包,MLflow 客户端可以使用它来上传文件。

from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
model_version = client.get_model_version(model_name, version)
run_id = model_version.run_id

# 您的训练代码

client.log_artifact(run_id, '<local/path/to/file.txt>', artifact_path="")
client.log_figure(run_id, figure, artifact_file="my_plot.png")
client.log_dict(run_id, my_dict, artifact_file="my_dict.json")
client.log_image(run_id, image, artifact_file="image.png")

工件随后将在 https/<your project>/-/ml/models/<model_id>/versions/<version_id> 下可用。

将模型版本链接到 CI/CD 作业

与运行类似,也可以将模型版本链接到 CI/CD 作业:

import os
from mlflow import MlflowClient

client = MlflowClient()
model_name = '<your_model_name>'
version = '<your_version>'
model_version = client.get_model_version(model_name, version)
run_id = model_version.run_id

# 您的训练代码

if os.getenv('GITLAB_CI'):
    client.set_tag(model_version.run_id, 'gitlab.CI_JOB_ID', os.getenv('CI_JOB_ID'))

支持的 MLflow 客户端方法和注意事项

GitLab 支持 MLflow 客户端的以下方法。更多信息请参见 MLflow 文档。以下方法的 MlflowClient 对应方法也支持,具有相同的注意事项。

方法 支持情况 版本添加 备注
create_experiment 15.11
get_experiment 15.11
get_experiment_by_name 15.11
delete_experiment 17.5
set_experiment 15.11
get_run 15.11
delete_run 17.5
start_run 15.11 (16.3) 如果未提供名称,运行将获得随机昵称。
search_runs 15.11 (16.4) experiment_ids 仅支持单个实验 ID,支持按列或指标排序。
log_artifact 有条件支持 15.11 (15.11) artifact_path 必须为空。不支持目录。
log_artifacts 有条件支持 15.11 (15.11) artifact_path 必须为空。不支持目录。
log_batch 15.11
log_metric 15.11
log_metrics 15.11
log_param 15.11
log_params 15.11
log_figure 15.11
log_image 15.11
log_text 有条件支持 15.11 (15.11) 不支持目录。
log_dict 有条件支持 15.11 (15.11) 不支持目录。
set_tag 15.11
set_tags 15.11
set_terminated 15.11
end_run 15.11
update_run 15.11
log_model 部分支持 15.11 (15.11) 保存工件,但不保存模型数据。artifact_path 必须为空。
load_model 17.5
download_artifacts 17.9
list_artifacts 17.9

其他 MLflowClient 方法:

方法 支持情况 版本添加 备注
create_registered_model 有条件支持 16.8 参见注意事项
get_registered_model 16.8
delete_registered_model 16.8
update_registered_model 16.8
create_model_version 有条件支持 16.8 参见注意事项
get_model_version 16.8
get_latest_versions 有条件支持 16.8 参见注意事项
update_model_version 16.8
create_registered_model 16.8
create_registered_model 16.8

已知问题

  • 未在 支持的方法 中列出的 MLflow 客户端方法可能仍然有效,但尚未经过测试。
  • 在创建实验和运行时,ExperimentTags 会被存储,即使它们不会显示。