package mlflow import ( "bytes" "fmt" "net/http" ) func CreateRegisteredModel(conf Config, req CreateRegisteredModelRequest) (*RegisteredModel, error) { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/create", "application/json", bytes.NewReader([]byte(req.serialize()))) r := CreateRegisteredModelResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.RegisteredModel, nil } func GetRegisteredModel(conf Config, name string) (*RegisteredModel, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/get?name=" + name) r := GetRegisteredModelResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.RegisteredModel, nil } func RenameRegisteredModel(conf Config, req RenameRegisteredModelRequrest) (*RegisteredModel, error) { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/rename", "application/json", bytes.NewReader([]byte(req.serialize()))) r := RenameRegisteredModelResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.RegisteredModel, nil } func DeleteRegisteredModel(conf Config, name string) error { client := &http.Client{} req, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete", bytes.NewReader([]byte(`{"name":"`+name+`"}`))) if err != nil { return err } req.Header.Add("ContentType", "application/json") resp, err := client.Do(req) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func GetLatestModelVersions(conf Config, req GetLatestModelVersionsRequest) (*[]ModelVersion, error) { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/get-latest-versions", "application/json", bytes.NewReader([]byte(req.serialize()))) r := GetLatestModelVersionsReponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersions, nil } func GetModelVersion(conf Config, req GetModelVersionRequest) (*ModelVersion, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get" + UrlEncode(req)) r := GetModelVersionResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersion, nil } func DeleteModelVersion(conf Config, req DeleteModelVersionRequest) error { client := &http.Client{} hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete", bytes.NewReader([]byte(req.serialize()))) if err != nil { return err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func CreateModelVersion(conf Config, req CreateModelVersionRequest) (*ModelVersion, error) { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/create", "application/json", bytes.NewReader([]byte(req.serialize()))) r := CreateModelVersionResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersion, nil } func GetDownloadUriForModelVersionArtifacts(conf Config, req GetDownloadUriForModelVersionArtifactsRequest) (*string, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get-download-uri" + UrlEncode(req)) r := GetDownloadUriForModelVersionArtifactsResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ArtifactUri, nil } func UpdateModelVersion(conf Config, req UpdateModelVersionRequest) (*ModelVersion, error) { client := &http.Client{} hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/model-versions/update", bytes.NewReader([]byte(req.serialize()))) if err != nil { return nil, err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) r := UpdateModelVersionResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersion, nil } func UpdateRegisteredModel(conf Config, req UpdateRegisteredModelRequest) (*RegisteredModel, error) { client := &http.Client{} hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/registered-models/update", bytes.NewReader([]byte(req.serialize()))) if err != nil { return nil, err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) r := UpdateRegisteredModelResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.RegisteredModel, nil } func TransitionModelVersionStage(conf Config, req TransitionModelVersionStageRequest) (*ModelVersion, error) { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/transition-stage", "application/json", bytes.NewReader([]byte(req.serialize()))) r := TransitionModelVersionStageResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersion, nil } func SearchModelVersions(conf Config, req SearchModelVersionsRequest) (*SearchModelVersionsResponse, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/search" + UrlEncode(*req.__init())) r := SearchModelVersionsResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r, nil } func SearchRegisteredModels(conf Config, req SearchRegisteredModelsRequest) (*SearchRegisteredModelsResponse, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/search" + UrlEncode(*req.__init())) r := SearchRegisteredModelsResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r, nil } func SetRegisteredModelTag(conf Config, req SetRegisteredModelTagRequest) error { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/set-tag", "application/json", bytes.NewReader([]byte(req.serialize()))) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func SetModelVersionTag(conf Config, req SetModelVersionTagRequest) error { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/set-tag", "application/json", bytes.NewReader([]byte(req.serialize()))) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func DeleteModelVersionTag(conf Config, req DeleteModelVersionTagRequest) error { client := &http.Client{} hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete-tag", bytes.NewReader([]byte(req.serialize()))) if err != nil { return err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func DeleteRegisteredModelTag(conf Config, req DeleteRegisteredModelTagRequest) error { client := &http.Client{} hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete-tag", bytes.NewReader([]byte(req.serialize()))) if err != nil { return err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func SetRegisteredModelAlias(conf Config, req SetRegisteredModelAliasRequest) error { resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", "application/json", bytes.NewReader([]byte(req.serialize()))) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func DeleteRegisteredModelAlias(conf Config, req DeleteRegisteredModelAliasRequest) error { client := &http.Client{} hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", bytes.NewReader([]byte(req.serialize()))) if err != nil { return err } hreq.Header.Add("ContentType", "application/json") resp, err := client.Do(hreq) err = apiReadReply(resp, err, nil) if err != nil { return err } return nil } func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*ModelVersion, error) { resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/alias" + UrlEncode(req)) r := GetModelVersionByAliasResponse{} err = apiReadReply(resp, err, &r) if err != nil { return nil, err } return &r.ModelVersion, nil } func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) { var models []ModelVersion = []ModelVersion{} var j int = 0 for { resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest) if err != nil { return nil, err } for _, item := range resp.ModelVersions { if req.Stage == "" || req.Stage == item.CurrentStage { models = append(models, item) continue } } if j > int(req.Limit) { fmt.Println("Limit exceed, terminating") break } if resp.NextPageToken == "" { break } else { req.PageToken = resp.NextPageToken } j++ } return models, nil }