Added SearchModelVersionsAdvanced method

This commit is contained in:
Pavel Dmitriev 2024-05-16 15:27:15 +03:00
parent 1340962d12
commit f83eae6e58
2 changed files with 47 additions and 6 deletions

View File

@ -2,6 +2,7 @@ package mlflow
import ( import (
"bytes" "bytes"
"fmt"
"net/http" "net/http"
) )
@ -285,3 +286,35 @@ func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*Mo
return &r.ModelVersion, nil return &r.ModelVersion, nil
} }
func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) {
var models []ModelVersion = []ModelVersion{}
var j int = 0
for {
resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest)
if err != nil {
return nil, err
}
for _, item := range resp.ModelVersions {
if req.Stage == "" || req.Stage == item.CurrentStage {
models = append(models, item)
continue
}
}
if j > int(req.Limit) {
fmt.Println("Limit exceed, terminating")
break
}
if resp.NextPageToken == "" {
break
} else {
req.PageToken = resp.NextPageToken
}
j++
}
return models, nil
}

View File

@ -138,8 +138,9 @@ type GetDownloadUriForModelVersionArtifactsResponse struct {
} }
type SearchModelVersionsRequest struct { type SearchModelVersionsRequest struct {
// Valid keys are '{'version_number', 'source_path', 'run_id', 'name'}'
Filter string `json:"filter,omitempty"` Filter string `json:"filter,omitempty"`
MaxResults int64 `json:"max_results,omitempty"` PageSize int64 `json:"max_results,omitempty"`
OrderBy string `json:"order_by,omitempty"` OrderBy string `json:"order_by,omitempty"`
PageToken string `json:"page_token,omitempty"` PageToken string `json:"page_token,omitempty"`
} }
@ -150,8 +151,8 @@ func (e *SearchModelVersionsRequest) serialize() []byte {
} }
func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest { func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest {
if e.MaxResults == 0 { if e.PageSize == 0 {
e.MaxResults = 1000 e.PageSize = 1000
} }
return e return e
@ -202,6 +203,7 @@ func (e *TransitionModelVersionStageRequest) serialize() []byte {
type TransitionModelVersionStageResponse GetModelVersionResponse type TransitionModelVersionStageResponse GetModelVersionResponse
type SearchRegisteredModelsRequest struct { type SearchRegisteredModelsRequest struct {
// Valid keys are '{'name'}'
Filter string `json:"filter,omitempty"` Filter string `json:"filter,omitempty"`
MaxResults int64 `json:"max_results,omitempty"` MaxResults int64 `json:"max_results,omitempty"`
OrderBy string `json:"order_by,omitempty"` OrderBy string `json:"order_by,omitempty"`
@ -300,3 +302,9 @@ func (e *DeleteRegisteredModelAliasRequest) serialize() []byte {
} }
type GetModelVersionByAliasResponse GetModelVersionResponse type GetModelVersionByAliasResponse GetModelVersionResponse
type SearchModelVersionsRequestAdvanced struct {
SearchModelVersionsRequest
Stage string
Limit int64
}