added func (ref *ModelVersion) IsEqual(target *ModelVersion) bool
This commit is contained in:
parent
8de04ec577
commit
c7aa20f7f1
|
@ -1,6 +1,11 @@
|
|||
package mlflow
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type ModelVersionStatus string
|
||||
|
||||
|
@ -60,6 +65,78 @@ type ModelVersion struct {
|
|||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func (ref *ModelVersion) GetTag(key string) *ModelVersionTag {
|
||||
for _, tag := range ref.Tags {
|
||||
if tag.Key == key {
|
||||
return &tag
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ref *ModelVersion) IsEqual(target *ModelVersion) bool {
|
||||
vRef := reflect.ValueOf(*ref)
|
||||
vTarget := reflect.ValueOf(*target)
|
||||
|
||||
typeOfV := vRef.Type()
|
||||
values := make([]interface{}, vRef.NumField())
|
||||
for i := 0; i < vRef.NumField(); i++ {
|
||||
values[i] = vRef.Field(i).Interface()
|
||||
refName := typeOfV.Field(i).Name
|
||||
refType := typeOfV.Field(i).Type.String()
|
||||
refVal := vRef.Field(i)
|
||||
|
||||
targetVal := reflect.Indirect(vTarget).FieldByName(refName)
|
||||
|
||||
switch refType {
|
||||
case "string", "mlflow.ModelVersionStatus":
|
||||
if refVal.String() != targetVal.String() {
|
||||
fmt.Println("NE STRING: ", refType, refName, refVal, targetVal, (refVal.String() == targetVal.String()))
|
||||
return false
|
||||
}
|
||||
case "mlflow.TimestampMs":
|
||||
if refVal.Int() != targetVal.Int() {
|
||||
fmt.Println("NE TimestampMs: ", refType, refName, refVal, targetVal, (refVal.Int() == targetVal.Int()))
|
||||
return false
|
||||
}
|
||||
case "[]mlflow.ModelVersionTag":
|
||||
sliceRef, _ := refVal.Interface().([]ModelVersionTag)
|
||||
sliceTarget, _ := targetVal.Interface().([]ModelVersionTag)
|
||||
|
||||
if len(sliceRef) != len(sliceTarget) {
|
||||
//fmt.Println("Not equal - len mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, val := range sliceRef {
|
||||
if tag2 := target.GetTag(val.Key); tag2 == nil || tag2.Value != val.Value {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case "[]string":
|
||||
sliceRef, _ := refVal.Interface().([]string)
|
||||
sliceTarget, _ := targetVal.Interface().([]string)
|
||||
sort.Strings(sliceRef)
|
||||
sort.Strings(sliceTarget)
|
||||
if len(sliceRef) != len(sliceTarget) {
|
||||
//fmt.Println("Not equal - len mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
for xI, xRef := range sliceRef {
|
||||
if xRef != sliceTarget[xI] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
default:
|
||||
fmt.Println("Unknown: ", refType, refName, refVal, targetVal)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type RegisteredModel struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
|
||||
|
|
Loading…
Reference in New Issue