commit 613fe39d8b73cdfd26b2527bd8af88ab506b62c7 Author: Pavel Dmitriev Date: Wed May 15 11:14:29 2024 +0300 initial commit diff --git a/api.go b/api.go new file mode 100644 index 0000000..55a9dc2 --- /dev/null +++ b/api.go @@ -0,0 +1,129 @@ +package mlflow + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strings" + "time" +) + +const ViewTypeAll = "ALL" +const ViewTypeActive = "ACTIVE_ONLY" +const ViewTypeDeleted = "DELETED_ONLY" + +type ViewType string + +type MlflowApiError struct { + ErrorCode string `json:"error_code,omitempty"` + ErrorDesc string `json:"message,omitempty"` +} + +func (e MlflowApiError) Error() string { + return fmt.Sprint("Request Error " + string(e.ErrorCode) + ": " + e.ErrorDesc) +} + +type Timestamp int64 + +func (t *Timestamp) Time() time.Time { + return time.Unix(int64(*t), 0) +} + +func (t *Timestamp) Set(time time.Time) Timestamp { + return Timestamp(time.Unix()) +} + +type TimestampMs int64 + +func (t *TimestampMs) Time() time.Time { + return time.Unix(int64(*t/1000), 0) +} + +func (t *TimestampMs) Set(time time.Time) TimestampMs { + return TimestampMs(time.Unix() * 1000) +} + +func apiReadReply(resp *http.Response, err error, ref any) error { + if err != nil { + fmt.Println(err.Error()) + return err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Println(err.Error()) + return err + } + + if resp.StatusCode != 200 && resp.StatusCode != 201 { + // Error code return + err := MlflowApiError{} + err2 := json.Unmarshal(body, &err) + if err2 != nil { + fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body)) + return err2 + } + return err + } + + err = json.Unmarshal(body, &ref) + + //fmt.Println(string(body)) + + if err != nil { + return err + } + return nil +} + +type typeTag struct { + FieldName string + Omitempty bool +} + +func parseJsonTag(tag reflect.StructTag, FieldName string) typeTag { + rv := typeTag{} + tag1 := strings.Split(string(tag), ":") + + if tag1[0] == "json" { + tag2 := strings.Split(strings.ReplaceAll(tag1[1], `"`, ""), ",") + rv.FieldName = tag2[0] + rv.Omitempty = (len(tag2) == 2 && tag2[1] == "omitempty") + } + if rv.FieldName == "" { + rv.FieldName = FieldName + } + return rv +} + +func UrlEncode(s any) string { + rv := "" + rvCount := 0 + + v := reflect.ValueOf(s) + typeOfV := v.Type() + + values := make([]interface{}, v.NumField()) + + for i := 0; i < v.NumField(); i++ { + values[i] = v.Field(i).Interface() + name := typeOfV.Field(i).Name + tag := parseJsonTag(typeOfV.Field(i).Tag, name) + val := v.Field(i) + + if tag.FieldName != "-" && (!tag.Omitempty || fmt.Sprint(val) != "") { + if rvCount > 0 { + rv += "&" + } else { + rv += "?" + } + rv += fmt.Sprintf("%s=%s", tag.FieldName, fmt.Sprint(val)) + rvCount++ + } + + } + return rv +} diff --git a/artifacts.go b/artifacts.go new file mode 100644 index 0000000..949a580 --- /dev/null +++ b/artifacts.go @@ -0,0 +1,36 @@ +package mlflow + +import "net/http" + +type ListArtifactsRequest struct { + RunId string `json:"run_id,omitempty"` + Path string `json:"path,omitempty"` + PageToken string `json:"page_token,omitempty"` +} + +func (l *ListArtifactsRequest) UrlEncode() string { + return UrlEncode(*l) +} + +type FileInfo struct { + Path string `json:"path,omitempty"` + IsDir bool `json:"is_dir,omitempty"` + FileSize int64 `json:"file_size,omitempty"` +} + +type ListArtifactsResponse struct { + RootUri string `json:"root_uri,omitempty"` + Files []FileInfo `json:"files,omitempty"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +func ListArtefacts(conf Config, req ListArtifactsRequest) (*ListArtifactsResponse, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/artifacts/list" + req.UrlEncode()) + rv := ListArtifactsResponse{} + err = apiReadReply(resp, err, &rv) + if err != nil { + return nil, err + } + + return &rv, nil +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..c30ae2f --- /dev/null +++ b/config.go @@ -0,0 +1,6 @@ +package mlflow + +type Config struct { + ApiURI string + IgnoreSSL bool +} diff --git a/experiment.go b/experiment.go new file mode 100644 index 0000000..ef0488d --- /dev/null +++ b/experiment.go @@ -0,0 +1,97 @@ +package mlflow + +import ( + "bytes" + "encoding/json" + "net/http" +) + +func SearchExperiments(conf Config, opts ExperimentSearchOptions) (*ExperimentSearchReply, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/search", "application/json", opts.getReader()) + exp := ExperimentSearchReply{} + err = apiReadReply(resp, err, &exp) + if err != nil { + return nil, err + } + + return &exp, nil +} + +func GetExperiment(conf Config, id string) (*Experiment, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/experiments/get" + "?experiment_id=" + id) + exp := ExperimentGetReply{} + err = apiReadReply(resp, err, &exp) + if err != nil { + return nil, err + } + + return &exp.Experiment, nil +} + +func GetExperimentByName(conf Config, name string) (*Experiment, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/experiments/get-by-name" + "?experiment_name=" + name) + exp := ExperimentGetReply{} + err = apiReadReply(resp, err, &exp) + if err != nil { + return nil, err + } + + return &exp.Experiment, nil +} + +func DeleteExperiment(conf Config, id string) error { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/delete", "application/json", bytes.NewReader([]byte(`{"experiment_id":"`+id+`"}`))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func RestoreExperiment(conf Config, id string) (*Experiment, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/restore", "application/json", bytes.NewReader([]byte(`{"experiment_id":"`+id+`"}`))) + err = apiReadReply(resp, err, nil) + if err != nil { + return nil, err + } + + return GetExperiment(conf, id) +} + +func UpdateExperiment(conf Config, id string, newName string) (*Experiment, error) { + args := `{"experiment_id":"` + id + `","new_name":"` + newName + `"}` + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/update", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return nil, err + } + + return GetExperiment(conf, id) +} + +func CreateExperiment(conf Config, exp *Experiment) (*Experiment, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/create", "application/json", bytes.NewReader([]byte(exp.serialize()))) + xp := ExperimentCreateReply{} + err = apiReadReply(resp, err, &xp) + if err != nil { + return nil, err + } + + return GetExperiment(conf, xp.ExperimentId) +} + +func SetExperimentTagTag(conf Config, opts SetExperimentTagRequest) error { + args, err := json.Marshal(opts) + if err != nil { + return err + } + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/experiments/set-experiment-tag", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} diff --git a/experiment_types.go b/experiment_types.go new file mode 100644 index 0000000..3e18944 --- /dev/null +++ b/experiment_types.go @@ -0,0 +1,69 @@ +package mlflow + +import ( + "bytes" + "encoding/json" +) + +type ExperimentSearchReply struct { + Experiments []Experiment `json:"experiments"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +type ExperimentGetReply struct { + Experiment Experiment `json:"experiment"` +} + +type ExperimentCreateReply struct { + ExperimentId string `json:"experiment_id"` +} + +type Experiment struct { + Id string `json:"experiment_id,omitempty"` + Name string `json:"name,omitempty"` + ArtifactLocation string `json:"artifact_location,omitempty"` + LifecycleStage string `json:"lifecycle_stage,omitempty"` + LastUpdate Timestamp `json:"last_update_time,omitempty"` + CreationTime Timestamp `json:"creation_time,omitempty"` + Tage []ExperimentTag `json:"tags,omitempty"` +} + +func (e *Experiment) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type ExperimentTag struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type ExperimentSearchOptions struct { + MaxResults int64 `json:"max_results,omitempty"` + Filter string `json:"filter,omitempty"` + PageToken string `json:"page_token,omitempty"` + OrderBy string `json:"order_by,omitempty"` + ViewType ViewType `json:"view_type,omitempty"` +} + +func (e *ExperimentSearchOptions) __init() *ExperimentSearchOptions { + if e.MaxResults == 0 { + e.MaxResults = 1000 + } + + return e +} + +func (e *ExperimentSearchOptions) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +func (e *ExperimentSearchOptions) getReader() *bytes.Reader { + return bytes.NewReader(e.__init().serialize()) +} + +type SetExperimentTagRequest struct { + ExperimentTag + ExperimentId string `json:"experiment_id,omitempty"` +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e5ba0ec --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module mxfox.ru/mlops/mlflow-client-go + +go 1.21.1 diff --git a/model.go b/model.go new file mode 100644 index 0000000..fd9b169 --- /dev/null +++ b/model.go @@ -0,0 +1,287 @@ +package mlflow + +import ( + "bytes" + "net/http" +) + +func CreateRegisteredModel(conf Config, req CreateRegisteredModelRequest) (*RegisteredModel, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/create", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := CreateRegisteredModelResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.RegisteredModel, nil +} + +func GetRegisteredModel(conf Config, name string) (*RegisteredModel, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/get?name=" + name) + r := GetRegisteredModelResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.RegisteredModel, nil +} + +func RenameRegisteredModel(conf Config, req RenameRegisteredModelRequrest) (*RegisteredModel, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/rename", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := RenameRegisteredModelResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.RegisteredModel, nil +} + +func DeleteRegisteredModel(conf Config, name string) error { + client := &http.Client{} + + req, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete", bytes.NewReader([]byte(`{"name":"`+name+`"}`))) + if err != nil { + return err + } + req.Header.Add("ContentType", "application/json") + resp, err := client.Do(req) + + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func GetLatestModelVersions(conf Config, req GetLatestModelVersionsRequest) (*[]ModelVersion, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/get-latest-versions", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := GetLatestModelVersionsReponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersions, nil +} + +func GetModelVersion(conf Config, req GetModelVersionRequest) (*ModelVersion, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get" + UrlEncode(req)) + r := GetModelVersionResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersion, nil +} + +func DeleteModelVersion(conf Config, req DeleteModelVersionRequest) error { + client := &http.Client{} + + hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func CreateModelVersion(conf Config, req CreateModelVersionRequest) (*ModelVersion, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/create", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := CreateModelVersionResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersion, nil +} + +func GetDownloadUriForModelVersionArtifacts(conf Config, req GetDownloadUriForModelVersionArtifactsRequest) (*string, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get-download-uri" + UrlEncode(req)) + r := GetDownloadUriForModelVersionArtifactsResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ArtifactUri, nil +} + +func UpdateModelVersion(conf Config, req UpdateModelVersionRequest) (*ModelVersion, error) { + client := &http.Client{} + + hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/model-versions/update", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return nil, err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + r := UpdateModelVersionResponse{} + + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersion, nil +} + +func UpdateRegisteredModel(conf Config, req UpdateRegisteredModelRequest) (*RegisteredModel, error) { + client := &http.Client{} + + hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/registered-models/update", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return nil, err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + r := UpdateRegisteredModelResponse{} + + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.RegisteredModel, nil +} + +func TransitionModelVersionStage(conf Config, req TransitionModelVersionStageRequest) (*ModelVersion, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/transition-stage", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := TransitionModelVersionStageResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersion, nil +} + +func SearchModelVersions(conf Config, req SearchModelVersionsRequest) (*SearchModelVersionsResponse, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/search" + UrlEncode(*req.__init())) + r := SearchModelVersionsResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r, nil +} + +func SearchRegisteredModels(conf Config, req SearchRegisteredModelsRequest) (*SearchRegisteredModelsResponse, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/search" + UrlEncode(*req.__init())) + r := SearchRegisteredModelsResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r, nil +} + +func SetRegisteredModelTag(conf Config, req SetRegisteredModelTagRequest) error { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/set-tag", "application/json", bytes.NewReader([]byte(req.serialize()))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func SetModelVersionTag(conf Config, req SetModelVersionTagRequest) error { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/set-tag", "application/json", bytes.NewReader([]byte(req.serialize()))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func DeleteModelVersionTag(conf Config, req DeleteModelVersionTagRequest) error { + client := &http.Client{} + + hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete-tag", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func DeleteRegisteredModelTag(conf Config, req DeleteRegisteredModelTagRequest) error { + client := &http.Client{} + + hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete-tag", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func SetRegisteredModelAlias(conf Config, req SetRegisteredModelAliasRequest) error { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", "application/json", bytes.NewReader([]byte(req.serialize()))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func DeleteRegisteredModelAlias(conf Config, req DeleteRegisteredModelAliasRequest) error { + client := &http.Client{} + + hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", bytes.NewReader([]byte(req.serialize()))) + if err != nil { + return err + } + hreq.Header.Add("ContentType", "application/json") + resp, err := client.Do(hreq) + + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*ModelVersion, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/alias" + UrlEncode(req)) + r := GetModelVersionByAliasResponse{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.ModelVersion, nil +} diff --git a/model_types.go b/model_types.go new file mode 100644 index 0000000..e0eb956 --- /dev/null +++ b/model_types.go @@ -0,0 +1,302 @@ +package mlflow + +import "encoding/json" + +type ModelVersionStatus string + +const ModelVersionStatusPendingRegistration = "PENDING_REGISTRATION" +const ModelVersionStatusFailedRegistration = "FAILED_REGISTRATION" +const ModelVersionStatusReady = "READY" + +const ModelVersionStageNone = "None" +const ModelVersionStageStaging = "Staging" +const ModelVersionStageProduction = "Production" +const ModelVersionStageArchived = "Archived" + +type RegisteredModelTag struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type ModelVersionTag struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type RegisteredModelAlias struct { + Alias string `json:"alias,omitempty"` + Version string `json:"version,omitempty"` +} + +type CreateRegisteredModelRequest struct { + Name string `json:"name,omitempty"` + Tags []RegisteredModelTag `json:"tags,omitempty"` + Description string `json:"description,omitempty"` +} + +func (r *CreateRegisteredModelRequest) serialize() []byte { + rv, _ := json.Marshal(r) + return rv +} + +type CreateRegisteredModelResponse struct { + RegisteredModel RegisteredModel `json:"registered_model,omitempty"` +} + +type ModelVersion struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` + CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"` + LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"` + UserId string `json:"user_id,omitempty"` + CurrentStage string `json:"current_stage,omitempty"` + Description string `json:"description,omitempty"` + Source string `json:"source,omitempty"` + RunId string `json:"run_id,omitempty"` + Status ModelVersionStatus `json:"status,omitempty"` + StatusMessage string `json:"status_message,omitempty"` + Tags []ModelVersionTag `json:"tags,omitempty"` + RunLink string `json:"run_link,omitempty"` + Aliases []string `json:"aliases,omitempty"` +} + +type RegisteredModel struct { + Name string `json:"name,omitempty"` + CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"` + LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"` + UserId string `json:"user_id,omitempty"` + Description string `json:"description,omitempty"` + LatestVersions []ModelVersion `json:"latest_versions,omitempty"` + Tags []RegisteredModelTag `json:"tags,omitempty"` + Aliases []RegisteredModelAlias `json:"aliases,omitempty"` +} + +type GetRegisteredModelResponse CreateRegisteredModelResponse + +type RenameRegisteredModelRequrest struct { + Name string `json:"name,omitempty"` + NewName string `json:"new_name,omitempty"` +} + +func (r *RenameRegisteredModelRequrest) serialize() []byte { + rv, _ := json.Marshal(r) + return rv +} + +type RenameRegisteredModelResponse CreateRegisteredModelResponse + +type GetLatestModelVersionsRequest struct { + Name string `json:"name,omitempty"` + Stages []string `json:"stages,omitempty"` +} + +func (r *GetLatestModelVersionsRequest) serialize() []byte { + rv, _ := json.Marshal(r) + return rv +} + +type GetLatestModelVersionsReponse struct { + ModelVersions []ModelVersion `json:"model_versions,omitempty"` +} + +type GetModelVersionRequest struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` +} + +type GetModelVersionResponse struct { + ModelVersion ModelVersion `json:"model_version,omitempty"` +} + +type DeleteModelVersionRequest GetModelVersionRequest + +func (r *DeleteModelVersionRequest) serialize() []byte { + rv, _ := json.Marshal(r) + return rv +} + +type CreateModelVersionRequest struct { + Name string `json:"name,omitempty"` + Source string `json:"source,omitempty"` + RunId string `json:"run_id,omitempty"` + Tags []ModelVersionTag `json:"tags,omitempty"` + RunLink string `json:"run_link,omitempty"` + Description string `json:"description,omitempty"` +} + +type CreateModelVersionResponse GetModelVersionResponse + +func (r *CreateModelVersionRequest) serialize() []byte { + rv, _ := json.Marshal(r) + return rv +} + +type GetDownloadUriForModelVersionArtifactsRequest GetModelVersionRequest + +type GetDownloadUriForModelVersionArtifactsResponse struct { + ArtifactUri string `json:"artifact_uri,omitempty"` +} + +type SearchModelVersionsRequest struct { + Filter string + MaxResults int64 + OrderBy []string + PageToken string +} + +func (e *SearchModelVersionsRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest { + if e.MaxResults == 0 { + e.MaxResults = 1000 + } + + return e +} + +type SearchModelVersionsResponse struct { + ModelVersions []ModelVersion `json:"model_versions,omitempty"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +type UpdateModelVersionRequest struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` + Description string `json:"description,omitempty"` +} + +type UpdateModelVersionResponse GetModelVersionResponse + +func (e *UpdateModelVersionRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type UpdateRegisteredModelRequest struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` +} + +type UpdateRegisteredModelResponse GetRegisteredModelResponse + +func (e *UpdateRegisteredModelRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type TransitionModelVersionStageRequest struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` + Stage string `json:"stage,omitempty"` + ArchiveExistingVersions bool `json:"archive_existing_versions,omitempty"` +} + +func (e *TransitionModelVersionStageRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type TransitionModelVersionStageResponse GetModelVersionResponse + +type SearchRegisteredModelsRequest struct { + Filter string + MaxResults int64 + OrderBy []string + PageToken string +} + +func (e *SearchRegisteredModelsRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +func (e *SearchRegisteredModelsRequest) __init() *SearchRegisteredModelsRequest { + if e.MaxResults == 0 { + e.MaxResults = 1000 + } + + return e +} + +type SearchRegisteredModelsResponse struct { + RegisteredModels []RegisteredModel `json:"registered_models,omitempty"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +type SetRegisteredModelTagRequest struct { + RegisteredModelTag + Name string `json:"name,omitempty"` +} + +func (e *SetRegisteredModelTagRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type SetModelVersionTagRequest struct { + ModelVersionTag + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` +} + +func (e *SetModelVersionTagRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type DeleteRegisteredModelTagRequest struct { + Key string `json:"key,omitempty"` + Name string `json:"name,omitempty"` +} + +func (e *DeleteRegisteredModelTagRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type DeleteModelVersionTagRequest struct { + Key string `json:"key,omitempty"` + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` +} + +func (e *DeleteModelVersionTagRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type SetRegisteredModelAliasRequest struct { + Name string `json:"name,omitempty"` + Alias string `json:"alias,omitempty"` + Version string `json:"version,omitempty"` +} + +func (e *SetRegisteredModelAliasRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type GetModelVersionByAliasRequest struct { + Name string `json:"name,omitempty"` + Alias string `json:"alias,omitempty"` +} + +func (e *GetModelVersionByAliasRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type DeleteRegisteredModelAliasRequest struct { + Name string `json:"name,omitempty"` + Alias string `json:"alias,omitempty"` +} + +func (e *DeleteRegisteredModelAliasRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +type GetModelVersionByAliasResponse GetModelVersionResponse diff --git a/run.go b/run.go new file mode 100644 index 0000000..d4a5795 --- /dev/null +++ b/run.go @@ -0,0 +1,201 @@ +package mlflow + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" +) + +type SearchRunReply struct { + Runs []Run `json:"runs,omitempty"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +type UpdateRunRequest struct { + RunId string `json:"run_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + EndTime TimestampMs `json:"end_time,omitempty"` + RunName string `json:"run_name,omitempty"` +} + +type UpdateRunReply struct { + RunInfo RunInfo `json:"run_info,omitempty"` +} + +func CreateRun(conf Config, req CreateRunRequest) (*Run, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/create", "application/json", bytes.NewReader([]byte(req.serialize()))) + r := CreateRunReply{} + err = apiReadReply(resp, err, &r) + if err != nil { + return nil, err + } + + return &r.Run, nil +} + +func DeleteRun(conf Config, id string) error { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/delete", "application/json", bytes.NewReader([]byte(`{"run_id":"`+id+`"}`))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func RestoreRun(conf Config, id string) (*Run, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/restore", "application/json", bytes.NewReader([]byte(`{"run_id":"`+id+`"}`))) + err = apiReadReply(resp, err, nil) + if err != nil { + return nil, err + } + + return GetRun(conf, id) +} + +func GetRun(conf Config, id string) (*Run, error) { + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/runs/get" + "?run_id=" + id) + run := GetRunReply{} + err = apiReadReply(resp, err, &run) + if err != nil { + return nil, err + } + + return &run.Run, nil +} + +func SearchRuns(conf Config, opts SearchRunRequest) (*SearchRunReply, error) { + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/search", "application/json", opts.getReader()) + rv := SearchRunReply{} + err = apiReadReply(resp, err, &rv) + if err != nil { + return nil, err + } + + return &rv, nil +} + +func UpdateRun(conf Config, opts UpdateRunRequest) (*RunInfo, error) { + args, err := json.Marshal(opts) + if err != nil { + return nil, err + } + rv := UpdateRunReply{} + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/update", "application/json", bytes.NewReader([]byte(args))) + + err = apiReadReply(resp, err, &rv) + if err != nil { + return nil, err + } + + return &rv.RunInfo, nil +} + +func LogMetric(conf Config, opts LogMetricRequest) error { + args, err := json.Marshal(opts.__init()) + if err != nil { + return err + } + + fmt.Println(opts) + fmt.Println(string(args)) + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-metric", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func LogBatch(conf Config, opts LogBatchRequest) error { + for i, metric := range opts.Metrics { + opts.Metrics[i] = *metric.__init() + } + + args, err := json.Marshal(opts) + if err != nil { + return err + } + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-batch", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func LogInputs(conf Config, opts LogInputsRequest) error { + args, err := json.Marshal(opts) + if err != nil { + return err + } + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-inputs", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func SetTag(conf Config, opts SetTagRequest) error { + args, err := json.Marshal(opts) + if err != nil { + return err + } + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/set-tag", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func DeleteTag(conf Config, opts DeleteTagRequest) error { + args, err := json.Marshal(opts) + if err != nil { + return err + } + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/delete-tag", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} + +func LogParam(conf Config, opts LogParamRequest) error { + args, err := json.Marshal(opts) + if err != nil { + return err + } + + resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-parameter", "application/json", bytes.NewReader([]byte(args))) + err = apiReadReply(resp, err, nil) + if err != nil { + return err + } + + return nil +} +func GetMetricHistory(conf Config, opts GetMetricHistoryRequest) (*GetMetricHistoryResponse, error) { + + rv := GetMetricHistoryResponse{} + resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/metrics/get-history" + UrlEncode(*opts.__init())) + + err = apiReadReply(resp, err, &rv) + if err != nil { + return nil, err + } + + return &rv, nil +} diff --git a/run_types.go b/run_types.go new file mode 100644 index 0000000..f6c65e7 --- /dev/null +++ b/run_types.go @@ -0,0 +1,186 @@ +package mlflow + +import ( + "bytes" + "encoding/json" + "time" +) + +type RunTag struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type CreateRunRequest struct { + ExperimentId string `json:"experiment_id,omitempty"` + UserId string `json:"user_id,omitempty"` + RunName string `json:"run_name,omitempty"` + StartTime TimestampMs `json:"start_time,omitempty"` + Tags []RunTag `json:"tags,omitempty"` +} + +type GetRunReply struct { + Run Run `json:"run,omitempty"` +} + +type CreateRunReply struct { + GetRunReply +} + +func (r *CreateRunRequest) serialize() []byte { + j, _ := json.Marshal(r) + return j +} + +type RunStatus string + +const RunStatusRunning = "RUNNING" +const RunStatusScheduled = "SCHEDULED" +const RunStatusFinished = "FINISHED" +const RunStatusFailed = "FAILED" +const RunStatusKilled = "KILLED" + +type RunInfo struct { + RunId string `json:"run_id,omitempty"` + RunName string `json:"run_name,omitempty"` + ExperimentId string `json:"experiment_id,omitempty"` + UserId string `json:"user_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + StartTime TimestampMs `json:"start_time,omitempty"` + EndTime TimestampMs `json:"end_time,omitempty"` + ArtifactUri string `json:"artifact_uri,omitempty"` + LifecycleStage string `json:"lifecycle_stage,omitempty"` +} + +type Metric struct { + Key string `json:"key,omitempty"` + Value float64 `json:"value,omitempty"` + Timestamp TimestampMs `json:"timestamp,omitempty"` + Step int64 `json:"step,omitempty"` +} + +func (r *Metric) __init() *Metric { + if r.Timestamp == 0 { + r.Timestamp = TimestampMs(time.Now().UnixMicro()) + } + return r +} + +type Param struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type RunData struct { + Metrics []Metric `json:"metics,omitempty"` + Params []Param `json:"params,omitempty"` + Tags []RunTag `json:"tags,omitempty"` +} + +type InputTag struct { + Key string `json:"key,omitempty"` + Value string `json:"value,omitempty"` +} + +type Dataset struct { + Name string `json:"name,omitempty"` + Digest string `json:"digest,omitempty"` + SourceType string `json:"source_type,omitempty"` + Source string `json:"source,omitempty"` + Schema string `json:"schema,omitempty"` + Profile string `json:"profile,omitempty"` +} + +type DatasetInput struct { + Tags []InputTag `json:"tags,omitempty"` + Dataset Dataset `json:"dataset,omitempty"` +} + +type RunInputs struct { + DatasetInputs []DatasetInput `json:"dataset_inputs,omitempty"` +} + +type Run struct { + Info RunInfo `json:"info,omitempty"` + Data RunData `json:"data,omitempty"` + Inputs RunInputs `json:"inputs,omitempty"` +} + +type SearchRunRequest struct { + ExperimentIds []string `json:"experiment_ids,omitempty"` + Filter string `json:"filter,omitempty"` + RunViewType ViewType `json:"run_view_type,omitempty"` + MaxResults int32 `json:"max_results,omitempty"` + OrgerBy []string `json:"order_by,omitempty"` + PageToken string `json:"page_token,omitempty"` +} + +func (e *SearchRunRequest) serialize() []byte { + j, _ := json.Marshal(e) + return j +} + +func (e *SearchRunRequest) __init() *SearchRunRequest { + if e.MaxResults == 0 { + e.MaxResults = 1000 + } + + return e +} + +func (e *SearchRunRequest) getReader() *bytes.Reader { + return bytes.NewReader(e.__init().serialize()) +} + +type LogMetricRequest struct { + Metric + RunId string `json:"run_id,omitempty"` +} + +func (fe *LogMetricRequest) __init() *LogMetricRequest { fe.Metric.__init(); return fe } + +type LogBatchRequest struct { + RunId string `json:"run_id,omitempty"` + Metrics []Metric `json:"metrics,omitempty"` + Params []Param `json:"params,omitempty"` + Tags []RunTag `json:"tags,omitempty"` +} + +type LogInputsRequest struct { + RunId string `json:"run_id,omitempty"` + Datasets []DatasetInput `json:"datasets,omitempty"` +} + +type SetTagRequest struct { + RunTag + RunId string `json:"run_id,omitempty"` +} + +type DeleteTagRequest struct { + RunTag + RunId string `json:"run_id,omitempty"` +} + +type LogParamRequest struct { + Param + RunId string `json:"run_id,omitempty"` +} + +type GetMetricHistoryResponse struct { + Metrics []Metric `json:"metrics,omitempty"` + NextPageToken string `json:"nex_page_token,omitempty"` +} + +type GetMetricHistoryRequest struct { + RunId string `json:"run_id,omitempty"` + MetricKey string `json:"metric_key,omitempty"` + PageToken string `json:"page_token,omitempty"` + MaxResults int32 `json:"max_results,omitempty"` +} + +func (r *GetMetricHistoryRequest) __init() *GetMetricHistoryRequest { + if r.MaxResults == 0 { + r.MaxResults = 1000 + } + return r +}