diff --git a/model.go b/model.go index fd9b169..4c74b96 100644 --- a/model.go +++ b/model.go @@ -2,6 +2,7 @@ package mlflow import ( "bytes" + "fmt" "net/http" ) @@ -285,3 +286,35 @@ func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*Mo return &r.ModelVersion, nil } + +func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) { + + var models []ModelVersion = []ModelVersion{} + var j int = 0 + for { + + resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest) + if err != nil { + return nil, err + } + + for _, item := range resp.ModelVersions { + if req.Stage == "" || req.Stage == item.CurrentStage { + models = append(models, item) + continue + } + } + if j > int(req.Limit) { + fmt.Println("Limit exceed, terminating") + break + } + if resp.NextPageToken == "" { + break + } else { + req.PageToken = resp.NextPageToken + } + j++ + } + + return models, nil +} diff --git a/model_types.go b/model_types.go index 3ef145b..e1e2685 100644 --- a/model_types.go +++ b/model_types.go @@ -138,10 +138,11 @@ type GetDownloadUriForModelVersionArtifactsResponse struct { } type SearchModelVersionsRequest struct { - Filter string `json:"filter,omitempty"` - MaxResults int64 `json:"max_results,omitempty"` - OrderBy string `json:"order_by,omitempty"` - PageToken string `json:"page_token,omitempty"` + // Valid keys are '{'version_number', 'source_path', 'run_id', 'name'}' + Filter string `json:"filter,omitempty"` + PageSize int64 `json:"max_results,omitempty"` + OrderBy string `json:"order_by,omitempty"` + PageToken string `json:"page_token,omitempty"` } func (e *SearchModelVersionsRequest) serialize() []byte { @@ -150,8 +151,8 @@ func (e *SearchModelVersionsRequest) serialize() []byte { } func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest { - if e.MaxResults == 0 { - e.MaxResults = 1000 + if e.PageSize == 0 { + e.PageSize = 1000 } return e @@ -202,6 +203,7 @@ func (e *TransitionModelVersionStageRequest) serialize() []byte { type TransitionModelVersionStageResponse GetModelVersionResponse type SearchRegisteredModelsRequest struct { + // Valid keys are '{'name'}' Filter string `json:"filter,omitempty"` MaxResults int64 `json:"max_results,omitempty"` OrderBy string `json:"order_by,omitempty"` @@ -300,3 +302,9 @@ func (e *DeleteRegisteredModelAliasRequest) serialize() []byte { } type GetModelVersionByAliasResponse GetModelVersionResponse + +type SearchModelVersionsRequestAdvanced struct { + SearchModelVersionsRequest + Stage string + Limit int64 +}