initial commit

This commit is contained in:
Pavel Dmitriev 2024-05-15 11:14:29 +03:00
commit 613fe39d8b
10 changed files with 1316 additions and 0 deletions

129
api.go Normal file
View File

@ -0,0 +1,129 @@
package mlflow
import (
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"time"
)
const ViewTypeAll = "ALL"
const ViewTypeActive = "ACTIVE_ONLY"
const ViewTypeDeleted = "DELETED_ONLY"
type ViewType string
type MlflowApiError struct {
ErrorCode string `json:"error_code,omitempty"`
ErrorDesc string `json:"message,omitempty"`
}
func (e MlflowApiError) Error() string {
return fmt.Sprint("Request Error " + string(e.ErrorCode) + ": " + e.ErrorDesc)
}
type Timestamp int64
func (t *Timestamp) Time() time.Time {
return time.Unix(int64(*t), 0)
}
func (t *Timestamp) Set(time time.Time) Timestamp {
return Timestamp(time.Unix())
}
type TimestampMs int64
func (t *TimestampMs) Time() time.Time {
return time.Unix(int64(*t/1000), 0)
}
func (t *TimestampMs) Set(time time.Time) TimestampMs {
return TimestampMs(time.Unix() * 1000)
}
func apiReadReply(resp *http.Response, err error, ref any) error {
if err != nil {
fmt.Println(err.Error())
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println(err.Error())
return err
}
if resp.StatusCode != 200 && resp.StatusCode != 201 {
// Error code return
err := MlflowApiError{}
err2 := json.Unmarshal(body, &err)
if err2 != nil {
fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body))
return err2
}
return err
}
err = json.Unmarshal(body, &ref)
//fmt.Println(string(body))
if err != nil {
return err
}
return nil
}
type typeTag struct {
FieldName string
Omitempty bool
}
func parseJsonTag(tag reflect.StructTag, FieldName string) typeTag {
rv := typeTag{}
tag1 := strings.Split(string(tag), ":")
if tag1[0] == "json" {
tag2 := strings.Split(strings.ReplaceAll(tag1[1], `"`, ""), ",")
rv.FieldName = tag2[0]
rv.Omitempty = (len(tag2) == 2 && tag2[1] == "omitempty")
}
if rv.FieldName == "" {
rv.FieldName = FieldName
}
return rv
}
func UrlEncode(s any) string {
rv := ""
rvCount := 0
v := reflect.ValueOf(s)
typeOfV := v.Type()
values := make([]interface{}, v.NumField())
for i := 0; i < v.NumField(); i++ {
values[i] = v.Field(i).Interface()
name := typeOfV.Field(i).Name
tag := parseJsonTag(typeOfV.Field(i).Tag, name)
val := v.Field(i)
if tag.FieldName != "-" && (!tag.Omitempty || fmt.Sprint(val) != "") {
if rvCount > 0 {
rv += "&"
} else {
rv += "?"
}
rv += fmt.Sprintf("%s=%s", tag.FieldName, fmt.Sprint(val))
rvCount++
}
}
return rv
}

36
artifacts.go Normal file
View File

@ -0,0 +1,36 @@
package mlflow
import "net/http"
type ListArtifactsRequest struct {
RunId string `json:"run_id,omitempty"`
Path string `json:"path,omitempty"`
PageToken string `json:"page_token,omitempty"`
}
func (l *ListArtifactsRequest) UrlEncode() string {
return UrlEncode(*l)
}
type FileInfo struct {
Path string `json:"path,omitempty"`
IsDir bool `json:"is_dir,omitempty"`
FileSize int64 `json:"file_size,omitempty"`
}
type ListArtifactsResponse struct {
RootUri string `json:"root_uri,omitempty"`
Files []FileInfo `json:"files,omitempty"`
NextPageToken string `json:"next_page_token,omitempty"`
}
func ListArtefacts(conf Config, req ListArtifactsRequest) (*ListArtifactsResponse, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/artifacts/list" + req.UrlEncode())
rv := ListArtifactsResponse{}
err = apiReadReply(resp, err, &rv)
if err != nil {
return nil, err
}
return &rv, nil
}

6
config.go Normal file
View File

@ -0,0 +1,6 @@
package mlflow
type Config struct {
ApiURI string
IgnoreSSL bool
}

97
experiment.go Normal file
View File

@ -0,0 +1,97 @@
package mlflow
import (
"bytes"
"encoding/json"
"net/http"
)
func SearchExperiments(conf Config, opts ExperimentSearchOptions) (*ExperimentSearchReply, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/search", "application/json", opts.getReader())
exp := ExperimentSearchReply{}
err = apiReadReply(resp, err, &exp)
if err != nil {
return nil, err
}
return &exp, nil
}
func GetExperiment(conf Config, id string) (*Experiment, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/experiments/get" + "?experiment_id=" + id)
exp := ExperimentGetReply{}
err = apiReadReply(resp, err, &exp)
if err != nil {
return nil, err
}
return &exp.Experiment, nil
}
func GetExperimentByName(conf Config, name string) (*Experiment, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/experiments/get-by-name" + "?experiment_name=" + name)
exp := ExperimentGetReply{}
err = apiReadReply(resp, err, &exp)
if err != nil {
return nil, err
}
return &exp.Experiment, nil
}
func DeleteExperiment(conf Config, id string) error {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/delete", "application/json", bytes.NewReader([]byte(`{"experiment_id":"`+id+`"}`)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func RestoreExperiment(conf Config, id string) (*Experiment, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/restore", "application/json", bytes.NewReader([]byte(`{"experiment_id":"`+id+`"}`)))
err = apiReadReply(resp, err, nil)
if err != nil {
return nil, err
}
return GetExperiment(conf, id)
}
func UpdateExperiment(conf Config, id string, newName string) (*Experiment, error) {
args := `{"experiment_id":"` + id + `","new_name":"` + newName + `"}`
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/update", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return nil, err
}
return GetExperiment(conf, id)
}
func CreateExperiment(conf Config, exp *Experiment) (*Experiment, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/create", "application/json", bytes.NewReader([]byte(exp.serialize())))
xp := ExperimentCreateReply{}
err = apiReadReply(resp, err, &xp)
if err != nil {
return nil, err
}
return GetExperiment(conf, xp.ExperimentId)
}
func SetExperimentTagTag(conf Config, opts SetExperimentTagRequest) error {
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/set-experiment-tag", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}

69
experiment_types.go Normal file
View File

@ -0,0 +1,69 @@
package mlflow
import (
"bytes"
"encoding/json"
)
type ExperimentSearchReply struct {
Experiments []Experiment `json:"experiments"`
NextPageToken string `json:"next_page_token,omitempty"`
}
type ExperimentGetReply struct {
Experiment Experiment `json:"experiment"`
}
type ExperimentCreateReply struct {
ExperimentId string `json:"experiment_id"`
}
type Experiment struct {
Id string `json:"experiment_id,omitempty"`
Name string `json:"name,omitempty"`
ArtifactLocation string `json:"artifact_location,omitempty"`
LifecycleStage string `json:"lifecycle_stage,omitempty"`
LastUpdate Timestamp `json:"last_update_time,omitempty"`
CreationTime Timestamp `json:"creation_time,omitempty"`
Tage []ExperimentTag `json:"tags,omitempty"`
}
func (e *Experiment) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type ExperimentTag struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type ExperimentSearchOptions struct {
MaxResults int64 `json:"max_results,omitempty"`
Filter string `json:"filter,omitempty"`
PageToken string `json:"page_token,omitempty"`
OrderBy string `json:"order_by,omitempty"`
ViewType ViewType `json:"view_type,omitempty"`
}
func (e *ExperimentSearchOptions) __init() *ExperimentSearchOptions {
if e.MaxResults == 0 {
e.MaxResults = 1000
}
return e
}
func (e *ExperimentSearchOptions) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
func (e *ExperimentSearchOptions) getReader() *bytes.Reader {
return bytes.NewReader(e.__init().serialize())
}
type SetExperimentTagRequest struct {
ExperimentTag
ExperimentId string `json:"experiment_id,omitempty"`
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module mxfox.ru/mlops/mlflow-client-go
go 1.21.1

287
model.go Normal file
View File

@ -0,0 +1,287 @@
package mlflow
import (
"bytes"
"net/http"
)
func CreateRegisteredModel(conf Config, req CreateRegisteredModelRequest) (*RegisteredModel, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/create", "application/json", bytes.NewReader([]byte(req.serialize())))
r := CreateRegisteredModelResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.RegisteredModel, nil
}
func GetRegisteredModel(conf Config, name string) (*RegisteredModel, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/get?name=" + name)
r := GetRegisteredModelResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.RegisteredModel, nil
}
func RenameRegisteredModel(conf Config, req RenameRegisteredModelRequrest) (*RegisteredModel, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/rename", "application/json", bytes.NewReader([]byte(req.serialize())))
r := RenameRegisteredModelResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.RegisteredModel, nil
}
func DeleteRegisteredModel(conf Config, name string) error {
client := &http.Client{}
req, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete", bytes.NewReader([]byte(`{"name":"`+name+`"}`)))
if err != nil {
return err
}
req.Header.Add("ContentType", "application/json")
resp, err := client.Do(req)
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func GetLatestModelVersions(conf Config, req GetLatestModelVersionsRequest) (*[]ModelVersion, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/get-latest-versions", "application/json", bytes.NewReader([]byte(req.serialize())))
r := GetLatestModelVersionsReponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersions, nil
}
func GetModelVersion(conf Config, req GetModelVersionRequest) (*ModelVersion, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get" + UrlEncode(req))
r := GetModelVersionResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersion, nil
}
func DeleteModelVersion(conf Config, req DeleteModelVersionRequest) error {
client := &http.Client{}
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func CreateModelVersion(conf Config, req CreateModelVersionRequest) (*ModelVersion, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/create", "application/json", bytes.NewReader([]byte(req.serialize())))
r := CreateModelVersionResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersion, nil
}
func GetDownloadUriForModelVersionArtifacts(conf Config, req GetDownloadUriForModelVersionArtifactsRequest) (*string, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get-download-uri" + UrlEncode(req))
r := GetDownloadUriForModelVersionArtifactsResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ArtifactUri, nil
}
func UpdateModelVersion(conf Config, req UpdateModelVersionRequest) (*ModelVersion, error) {
client := &http.Client{}
hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/model-versions/update", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return nil, err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
r := UpdateModelVersionResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersion, nil
}
func UpdateRegisteredModel(conf Config, req UpdateRegisteredModelRequest) (*RegisteredModel, error) {
client := &http.Client{}
hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/registered-models/update", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return nil, err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
r := UpdateRegisteredModelResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.RegisteredModel, nil
}
func TransitionModelVersionStage(conf Config, req TransitionModelVersionStageRequest) (*ModelVersion, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/transition-stage", "application/json", bytes.NewReader([]byte(req.serialize())))
r := TransitionModelVersionStageResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersion, nil
}
func SearchModelVersions(conf Config, req SearchModelVersionsRequest) (*SearchModelVersionsResponse, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/search" + UrlEncode(*req.__init()))
r := SearchModelVersionsResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r, nil
}
func SearchRegisteredModels(conf Config, req SearchRegisteredModelsRequest) (*SearchRegisteredModelsResponse, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/search" + UrlEncode(*req.__init()))
r := SearchRegisteredModelsResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r, nil
}
func SetRegisteredModelTag(conf Config, req SetRegisteredModelTagRequest) error {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/set-tag", "application/json", bytes.NewReader([]byte(req.serialize())))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func SetModelVersionTag(conf Config, req SetModelVersionTagRequest) error {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/set-tag", "application/json", bytes.NewReader([]byte(req.serialize())))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func DeleteModelVersionTag(conf Config, req DeleteModelVersionTagRequest) error {
client := &http.Client{}
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete-tag", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func DeleteRegisteredModelTag(conf Config, req DeleteRegisteredModelTagRequest) error {
client := &http.Client{}
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete-tag", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func SetRegisteredModelAlias(conf Config, req SetRegisteredModelAliasRequest) error {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", "application/json", bytes.NewReader([]byte(req.serialize())))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func DeleteRegisteredModelAlias(conf Config, req DeleteRegisteredModelAliasRequest) error {
client := &http.Client{}
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", bytes.NewReader([]byte(req.serialize())))
if err != nil {
return err
}
hreq.Header.Add("ContentType", "application/json")
resp, err := client.Do(hreq)
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*ModelVersion, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/alias" + UrlEncode(req))
r := GetModelVersionByAliasResponse{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.ModelVersion, nil
}

302
model_types.go Normal file
View File

@ -0,0 +1,302 @@
package mlflow
import "encoding/json"
type ModelVersionStatus string
const ModelVersionStatusPendingRegistration = "PENDING_REGISTRATION"
const ModelVersionStatusFailedRegistration = "FAILED_REGISTRATION"
const ModelVersionStatusReady = "READY"
const ModelVersionStageNone = "None"
const ModelVersionStageStaging = "Staging"
const ModelVersionStageProduction = "Production"
const ModelVersionStageArchived = "Archived"
type RegisteredModelTag struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type ModelVersionTag struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type RegisteredModelAlias struct {
Alias string `json:"alias,omitempty"`
Version string `json:"version,omitempty"`
}
type CreateRegisteredModelRequest struct {
Name string `json:"name,omitempty"`
Tags []RegisteredModelTag `json:"tags,omitempty"`
Description string `json:"description,omitempty"`
}
func (r *CreateRegisteredModelRequest) serialize() []byte {
rv, _ := json.Marshal(r)
return rv
}
type CreateRegisteredModelResponse struct {
RegisteredModel RegisteredModel `json:"registered_model,omitempty"`
}
type ModelVersion struct {
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"`
UserId string `json:"user_id,omitempty"`
CurrentStage string `json:"current_stage,omitempty"`
Description string `json:"description,omitempty"`
Source string `json:"source,omitempty"`
RunId string `json:"run_id,omitempty"`
Status ModelVersionStatus `json:"status,omitempty"`
StatusMessage string `json:"status_message,omitempty"`
Tags []ModelVersionTag `json:"tags,omitempty"`
RunLink string `json:"run_link,omitempty"`
Aliases []string `json:"aliases,omitempty"`
}
type RegisteredModel struct {
Name string `json:"name,omitempty"`
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"`
UserId string `json:"user_id,omitempty"`
Description string `json:"description,omitempty"`
LatestVersions []ModelVersion `json:"latest_versions,omitempty"`
Tags []RegisteredModelTag `json:"tags,omitempty"`
Aliases []RegisteredModelAlias `json:"aliases,omitempty"`
}
type GetRegisteredModelResponse CreateRegisteredModelResponse
type RenameRegisteredModelRequrest struct {
Name string `json:"name,omitempty"`
NewName string `json:"new_name,omitempty"`
}
func (r *RenameRegisteredModelRequrest) serialize() []byte {
rv, _ := json.Marshal(r)
return rv
}
type RenameRegisteredModelResponse CreateRegisteredModelResponse
type GetLatestModelVersionsRequest struct {
Name string `json:"name,omitempty"`
Stages []string `json:"stages,omitempty"`
}
func (r *GetLatestModelVersionsRequest) serialize() []byte {
rv, _ := json.Marshal(r)
return rv
}
type GetLatestModelVersionsReponse struct {
ModelVersions []ModelVersion `json:"model_versions,omitempty"`
}
type GetModelVersionRequest struct {
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
}
type GetModelVersionResponse struct {
ModelVersion ModelVersion `json:"model_version,omitempty"`
}
type DeleteModelVersionRequest GetModelVersionRequest
func (r *DeleteModelVersionRequest) serialize() []byte {
rv, _ := json.Marshal(r)
return rv
}
type CreateModelVersionRequest struct {
Name string `json:"name,omitempty"`
Source string `json:"source,omitempty"`
RunId string `json:"run_id,omitempty"`
Tags []ModelVersionTag `json:"tags,omitempty"`
RunLink string `json:"run_link,omitempty"`
Description string `json:"description,omitempty"`
}
type CreateModelVersionResponse GetModelVersionResponse
func (r *CreateModelVersionRequest) serialize() []byte {
rv, _ := json.Marshal(r)
return rv
}
type GetDownloadUriForModelVersionArtifactsRequest GetModelVersionRequest
type GetDownloadUriForModelVersionArtifactsResponse struct {
ArtifactUri string `json:"artifact_uri,omitempty"`
}
type SearchModelVersionsRequest struct {
Filter string
MaxResults int64
OrderBy []string
PageToken string
}
func (e *SearchModelVersionsRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest {
if e.MaxResults == 0 {
e.MaxResults = 1000
}
return e
}
type SearchModelVersionsResponse struct {
ModelVersions []ModelVersion `json:"model_versions,omitempty"`
NextPageToken string `json:"next_page_token,omitempty"`
}
type UpdateModelVersionRequest struct {
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Description string `json:"description,omitempty"`
}
type UpdateModelVersionResponse GetModelVersionResponse
func (e *UpdateModelVersionRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type UpdateRegisteredModelRequest struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
}
type UpdateRegisteredModelResponse GetRegisteredModelResponse
func (e *UpdateRegisteredModelRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type TransitionModelVersionStageRequest struct {
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Stage string `json:"stage,omitempty"`
ArchiveExistingVersions bool `json:"archive_existing_versions,omitempty"`
}
func (e *TransitionModelVersionStageRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type TransitionModelVersionStageResponse GetModelVersionResponse
type SearchRegisteredModelsRequest struct {
Filter string
MaxResults int64
OrderBy []string
PageToken string
}
func (e *SearchRegisteredModelsRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
func (e *SearchRegisteredModelsRequest) __init() *SearchRegisteredModelsRequest {
if e.MaxResults == 0 {
e.MaxResults = 1000
}
return e
}
type SearchRegisteredModelsResponse struct {
RegisteredModels []RegisteredModel `json:"registered_models,omitempty"`
NextPageToken string `json:"next_page_token,omitempty"`
}
type SetRegisteredModelTagRequest struct {
RegisteredModelTag
Name string `json:"name,omitempty"`
}
func (e *SetRegisteredModelTagRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type SetModelVersionTagRequest struct {
ModelVersionTag
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
}
func (e *SetModelVersionTagRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type DeleteRegisteredModelTagRequest struct {
Key string `json:"key,omitempty"`
Name string `json:"name,omitempty"`
}
func (e *DeleteRegisteredModelTagRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type DeleteModelVersionTagRequest struct {
Key string `json:"key,omitempty"`
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
}
func (e *DeleteModelVersionTagRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type SetRegisteredModelAliasRequest struct {
Name string `json:"name,omitempty"`
Alias string `json:"alias,omitempty"`
Version string `json:"version,omitempty"`
}
func (e *SetRegisteredModelAliasRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type GetModelVersionByAliasRequest struct {
Name string `json:"name,omitempty"`
Alias string `json:"alias,omitempty"`
}
func (e *GetModelVersionByAliasRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type DeleteRegisteredModelAliasRequest struct {
Name string `json:"name,omitempty"`
Alias string `json:"alias,omitempty"`
}
func (e *DeleteRegisteredModelAliasRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
type GetModelVersionByAliasResponse GetModelVersionResponse

201
run.go Normal file
View File

@ -0,0 +1,201 @@
package mlflow
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
)
type SearchRunReply struct {
Runs []Run `json:"runs,omitempty"`
NextPageToken string `json:"next_page_token,omitempty"`
}
type UpdateRunRequest struct {
RunId string `json:"run_id,omitempty"`
Status RunStatus `json:"status,omitempty"`
EndTime TimestampMs `json:"end_time,omitempty"`
RunName string `json:"run_name,omitempty"`
}
type UpdateRunReply struct {
RunInfo RunInfo `json:"run_info,omitempty"`
}
func CreateRun(conf Config, req CreateRunRequest) (*Run, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/create", "application/json", bytes.NewReader([]byte(req.serialize())))
r := CreateRunReply{}
err = apiReadReply(resp, err, &r)
if err != nil {
return nil, err
}
return &r.Run, nil
}
func DeleteRun(conf Config, id string) error {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/delete", "application/json", bytes.NewReader([]byte(`{"run_id":"`+id+`"}`)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func RestoreRun(conf Config, id string) (*Run, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/restore", "application/json", bytes.NewReader([]byte(`{"run_id":"`+id+`"}`)))
err = apiReadReply(resp, err, nil)
if err != nil {
return nil, err
}
return GetRun(conf, id)
}
func GetRun(conf Config, id string) (*Run, error) {
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/runs/get" + "?run_id=" + id)
run := GetRunReply{}
err = apiReadReply(resp, err, &run)
if err != nil {
return nil, err
}
return &run.Run, nil
}
func SearchRuns(conf Config, opts SearchRunRequest) (*SearchRunReply, error) {
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/search", "application/json", opts.getReader())
rv := SearchRunReply{}
err = apiReadReply(resp, err, &rv)
if err != nil {
return nil, err
}
return &rv, nil
}
func UpdateRun(conf Config, opts UpdateRunRequest) (*RunInfo, error) {
args, err := json.Marshal(opts)
if err != nil {
return nil, err
}
rv := UpdateRunReply{}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/update", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, &rv)
if err != nil {
return nil, err
}
return &rv.RunInfo, nil
}
func LogMetric(conf Config, opts LogMetricRequest) error {
args, err := json.Marshal(opts.__init())
if err != nil {
return err
}
fmt.Println(opts)
fmt.Println(string(args))
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-metric", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func LogBatch(conf Config, opts LogBatchRequest) error {
for i, metric := range opts.Metrics {
opts.Metrics[i] = *metric.__init()
}
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-batch", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func LogInputs(conf Config, opts LogInputsRequest) error {
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-inputs", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func SetTag(conf Config, opts SetTagRequest) error {
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/set-tag", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func DeleteTag(conf Config, opts DeleteTagRequest) error {
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/delete-tag", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func LogParam(conf Config, opts LogParamRequest) error {
args, err := json.Marshal(opts)
if err != nil {
return err
}
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-parameter", "application/json", bytes.NewReader([]byte(args)))
err = apiReadReply(resp, err, nil)
if err != nil {
return err
}
return nil
}
func GetMetricHistory(conf Config, opts GetMetricHistoryRequest) (*GetMetricHistoryResponse, error) {
rv := GetMetricHistoryResponse{}
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/metrics/get-history" + UrlEncode(*opts.__init()))
err = apiReadReply(resp, err, &rv)
if err != nil {
return nil, err
}
return &rv, nil
}

186
run_types.go Normal file
View File

@ -0,0 +1,186 @@
package mlflow
import (
"bytes"
"encoding/json"
"time"
)
type RunTag struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type CreateRunRequest struct {
ExperimentId string `json:"experiment_id,omitempty"`
UserId string `json:"user_id,omitempty"`
RunName string `json:"run_name,omitempty"`
StartTime TimestampMs `json:"start_time,omitempty"`
Tags []RunTag `json:"tags,omitempty"`
}
type GetRunReply struct {
Run Run `json:"run,omitempty"`
}
type CreateRunReply struct {
GetRunReply
}
func (r *CreateRunRequest) serialize() []byte {
j, _ := json.Marshal(r)
return j
}
type RunStatus string
const RunStatusRunning = "RUNNING"
const RunStatusScheduled = "SCHEDULED"
const RunStatusFinished = "FINISHED"
const RunStatusFailed = "FAILED"
const RunStatusKilled = "KILLED"
type RunInfo struct {
RunId string `json:"run_id,omitempty"`
RunName string `json:"run_name,omitempty"`
ExperimentId string `json:"experiment_id,omitempty"`
UserId string `json:"user_id,omitempty"`
Status RunStatus `json:"status,omitempty"`
StartTime TimestampMs `json:"start_time,omitempty"`
EndTime TimestampMs `json:"end_time,omitempty"`
ArtifactUri string `json:"artifact_uri,omitempty"`
LifecycleStage string `json:"lifecycle_stage,omitempty"`
}
type Metric struct {
Key string `json:"key,omitempty"`
Value float64 `json:"value,omitempty"`
Timestamp TimestampMs `json:"timestamp,omitempty"`
Step int64 `json:"step,omitempty"`
}
func (r *Metric) __init() *Metric {
if r.Timestamp == 0 {
r.Timestamp = TimestampMs(time.Now().UnixMicro())
}
return r
}
type Param struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type RunData struct {
Metrics []Metric `json:"metics,omitempty"`
Params []Param `json:"params,omitempty"`
Tags []RunTag `json:"tags,omitempty"`
}
type InputTag struct {
Key string `json:"key,omitempty"`
Value string `json:"value,omitempty"`
}
type Dataset struct {
Name string `json:"name,omitempty"`
Digest string `json:"digest,omitempty"`
SourceType string `json:"source_type,omitempty"`
Source string `json:"source,omitempty"`
Schema string `json:"schema,omitempty"`
Profile string `json:"profile,omitempty"`
}
type DatasetInput struct {
Tags []InputTag `json:"tags,omitempty"`
Dataset Dataset `json:"dataset,omitempty"`
}
type RunInputs struct {
DatasetInputs []DatasetInput `json:"dataset_inputs,omitempty"`
}
type Run struct {
Info RunInfo `json:"info,omitempty"`
Data RunData `json:"data,omitempty"`
Inputs RunInputs `json:"inputs,omitempty"`
}
type SearchRunRequest struct {
ExperimentIds []string `json:"experiment_ids,omitempty"`
Filter string `json:"filter,omitempty"`
RunViewType ViewType `json:"run_view_type,omitempty"`
MaxResults int32 `json:"max_results,omitempty"`
OrgerBy []string `json:"order_by,omitempty"`
PageToken string `json:"page_token,omitempty"`
}
func (e *SearchRunRequest) serialize() []byte {
j, _ := json.Marshal(e)
return j
}
func (e *SearchRunRequest) __init() *SearchRunRequest {
if e.MaxResults == 0 {
e.MaxResults = 1000
}
return e
}
func (e *SearchRunRequest) getReader() *bytes.Reader {
return bytes.NewReader(e.__init().serialize())
}
type LogMetricRequest struct {
Metric
RunId string `json:"run_id,omitempty"`
}
func (fe *LogMetricRequest) __init() *LogMetricRequest { fe.Metric.__init(); return fe }
type LogBatchRequest struct {
RunId string `json:"run_id,omitempty"`
Metrics []Metric `json:"metrics,omitempty"`
Params []Param `json:"params,omitempty"`
Tags []RunTag `json:"tags,omitempty"`
}
type LogInputsRequest struct {
RunId string `json:"run_id,omitempty"`
Datasets []DatasetInput `json:"datasets,omitempty"`
}
type SetTagRequest struct {
RunTag
RunId string `json:"run_id,omitempty"`
}
type DeleteTagRequest struct {
RunTag
RunId string `json:"run_id,omitempty"`
}
type LogParamRequest struct {
Param
RunId string `json:"run_id,omitempty"`
}
type GetMetricHistoryResponse struct {
Metrics []Metric `json:"metrics,omitempty"`
NextPageToken string `json:"nex_page_token,omitempty"`
}
type GetMetricHistoryRequest struct {
RunId string `json:"run_id,omitempty"`
MetricKey string `json:"metric_key,omitempty"`
PageToken string `json:"page_token,omitempty"`
MaxResults int32 `json:"max_results,omitempty"`
}
func (r *GetMetricHistoryRequest) __init() *GetMetricHistoryRequest {
if r.MaxResults == 0 {
r.MaxResults = 1000
}
return r
}