added func (ref *ModelVersion) IsEqual(target *ModelVersion) bool

This commit is contained in:
Pavel Dmitriev 2024-05-21 17:48:56 +03:00
parent 8de04ec577
commit c7aa20f7f1
1 changed files with 78 additions and 1 deletions

View File

@ -1,6 +1,11 @@
package mlflow package mlflow
import "encoding/json" import (
"encoding/json"
"fmt"
"reflect"
"sort"
)
type ModelVersionStatus string type ModelVersionStatus string
@ -60,6 +65,78 @@ type ModelVersion struct {
Aliases []string `json:"aliases,omitempty"` Aliases []string `json:"aliases,omitempty"`
} }
func (ref *ModelVersion) GetTag(key string) *ModelVersionTag {
for _, tag := range ref.Tags {
if tag.Key == key {
return &tag
}
}
return nil
}
func (ref *ModelVersion) IsEqual(target *ModelVersion) bool {
vRef := reflect.ValueOf(*ref)
vTarget := reflect.ValueOf(*target)
typeOfV := vRef.Type()
values := make([]interface{}, vRef.NumField())
for i := 0; i < vRef.NumField(); i++ {
values[i] = vRef.Field(i).Interface()
refName := typeOfV.Field(i).Name
refType := typeOfV.Field(i).Type.String()
refVal := vRef.Field(i)
targetVal := reflect.Indirect(vTarget).FieldByName(refName)
switch refType {
case "string", "mlflow.ModelVersionStatus":
if refVal.String() != targetVal.String() {
fmt.Println("NE STRING: ", refType, refName, refVal, targetVal, (refVal.String() == targetVal.String()))
return false
}
case "mlflow.TimestampMs":
if refVal.Int() != targetVal.Int() {
fmt.Println("NE TimestampMs: ", refType, refName, refVal, targetVal, (refVal.Int() == targetVal.Int()))
return false
}
case "[]mlflow.ModelVersionTag":
sliceRef, _ := refVal.Interface().([]ModelVersionTag)
sliceTarget, _ := targetVal.Interface().([]ModelVersionTag)
if len(sliceRef) != len(sliceTarget) {
//fmt.Println("Not equal - len mismatch")
return false
}
for _, val := range sliceRef {
if tag2 := target.GetTag(val.Key); tag2 == nil || tag2.Value != val.Value {
return false
}
}
case "[]string":
sliceRef, _ := refVal.Interface().([]string)
sliceTarget, _ := targetVal.Interface().([]string)
sort.Strings(sliceRef)
sort.Strings(sliceTarget)
if len(sliceRef) != len(sliceTarget) {
//fmt.Println("Not equal - len mismatch")
return false
}
for xI, xRef := range sliceRef {
if xRef != sliceTarget[xI] {
return false
}
}
default:
fmt.Println("Unknown: ", refType, refName, refVal, targetVal)
}
}
return true
}
type RegisteredModel struct { type RegisteredModel struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"` CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`