2024-05-15 08:14:29 +00:00
|
|
|
package mlflow
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2024-05-16 12:27:15 +00:00
|
|
|
"fmt"
|
2024-05-15 08:14:29 +00:00
|
|
|
"net/http"
|
|
|
|
)
|
|
|
|
|
|
|
|
func CreateRegisteredModel(conf Config, req CreateRegisteredModelRequest) (*RegisteredModel, error) {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/create", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
r := CreateRegisteredModelResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.RegisteredModel, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetRegisteredModel(conf Config, name string) (*RegisteredModel, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/get?name=" + name)
|
|
|
|
r := GetRegisteredModelResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.RegisteredModel, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func RenameRegisteredModel(conf Config, req RenameRegisteredModelRequrest) (*RegisteredModel, error) {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/rename", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
r := RenameRegisteredModelResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.RegisteredModel, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DeleteRegisteredModel(conf Config, name string) error {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
req, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete", bytes.NewReader([]byte(`{"name":"`+name+`"}`)))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
req.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetLatestModelVersions(conf Config, req GetLatestModelVersionsRequest) (*[]ModelVersion, error) {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/get-latest-versions", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
r := GetLatestModelVersionsReponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersions, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetModelVersion(conf Config, req GetModelVersionRequest) (*ModelVersion, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get" + UrlEncode(req))
|
|
|
|
r := GetModelVersionResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersion, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DeleteModelVersion(conf Config, req DeleteModelVersionRequest) error {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func CreateModelVersion(conf Config, req CreateModelVersionRequest) (*ModelVersion, error) {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/create", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
r := CreateModelVersionResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersion, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetDownloadUriForModelVersionArtifacts(conf Config, req GetDownloadUriForModelVersionArtifactsRequest) (*string, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/get-download-uri" + UrlEncode(req))
|
|
|
|
r := GetDownloadUriForModelVersionArtifactsResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ArtifactUri, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func UpdateModelVersion(conf Config, req UpdateModelVersionRequest) (*ModelVersion, error) {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/model-versions/update", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
r := UpdateModelVersionResponse{}
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersion, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func UpdateRegisteredModel(conf Config, req UpdateRegisteredModelRequest) (*RegisteredModel, error) {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("PATCH", conf.ApiURI+"/api/2.0/mlflow/registered-models/update", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
r := UpdateRegisteredModelResponse{}
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.RegisteredModel, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func TransitionModelVersionStage(conf Config, req TransitionModelVersionStageRequest) (*ModelVersion, error) {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/transition-stage", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
r := TransitionModelVersionStageResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersion, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func SearchModelVersions(conf Config, req SearchModelVersionsRequest) (*SearchModelVersionsResponse, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/model-versions/search" + UrlEncode(*req.__init()))
|
|
|
|
r := SearchModelVersionsResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func SearchRegisteredModels(conf Config, req SearchRegisteredModelsRequest) (*SearchRegisteredModelsResponse, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/search" + UrlEncode(*req.__init()))
|
|
|
|
r := SearchRegisteredModelsResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func SetRegisteredModelTag(conf Config, req SetRegisteredModelTagRequest) error {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/set-tag", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func SetModelVersionTag(conf Config, req SetModelVersionTagRequest) error {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/model-versions/set-tag", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DeleteModelVersionTag(conf Config, req DeleteModelVersionTagRequest) error {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/model-versions/delete-tag", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DeleteRegisteredModelTag(conf Config, req DeleteRegisteredModelTagRequest) error {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/delete-tag", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func SetRegisteredModelAlias(conf Config, req SetRegisteredModelAliasRequest) error {
|
|
|
|
resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", "application/json", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func DeleteRegisteredModelAlias(conf Config, req DeleteRegisteredModelAliasRequest) error {
|
|
|
|
client := &http.Client{}
|
|
|
|
|
|
|
|
hreq, err := http.NewRequest("DELETE", conf.ApiURI+"/api/2.0/mlflow/registered-models/alias", bytes.NewReader([]byte(req.serialize())))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
hreq.Header.Add("ContentType", "application/json")
|
|
|
|
resp, err := client.Do(hreq)
|
|
|
|
|
|
|
|
err = apiReadReply(resp, err, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*ModelVersion, error) {
|
|
|
|
resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/registered-models/alias" + UrlEncode(req))
|
|
|
|
r := GetModelVersionByAliasResponse{}
|
|
|
|
err = apiReadReply(resp, err, &r)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &r.ModelVersion, nil
|
|
|
|
}
|
2024-05-16 12:27:15 +00:00
|
|
|
|
|
|
|
func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) {
|
|
|
|
|
|
|
|
var models []ModelVersion = []ModelVersion{}
|
|
|
|
var j int = 0
|
|
|
|
for {
|
|
|
|
|
|
|
|
resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, item := range resp.ModelVersions {
|
|
|
|
if req.Stage == "" || req.Stage == item.CurrentStage {
|
|
|
|
models = append(models, item)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if j > int(req.Limit) {
|
|
|
|
fmt.Println("Limit exceed, terminating")
|
|
|
|
break
|
|
|
|
}
|
|
|
|
if resp.NextPageToken == "" {
|
|
|
|
break
|
|
|
|
} else {
|
|
|
|
req.PageToken = resp.NextPageToken
|
|
|
|
}
|
|
|
|
j++
|
|
|
|
}
|
|
|
|
|
|
|
|
return models, nil
|
|
|
|
}
|