迭代器模式与提供复杂数据结构查询的API

迭代器模式主要用在访问对象集合的场景,能够向客户端隐藏集合的实现细节

迭代器模式与提供复杂数据结构查询的API

有时会遇到这样的需求,开发一个模块,用于保存对象;不能用简单的数组、列表,得是红黑树、跳表等较为复杂的数据结构;有时为了提升存储效率或持久化,还得将对象序列化;但必须给客户端提供一个易用的 API,允许方便地、多种方式地遍历对象,丝毫不察觉背后的数据结构有多复杂

从描述可知,迭代器模式主要用在访问对象集合的场景,能够向客户端隐藏集合的实现细节

Java 的 Collection 家族、C++ 的 STL 标准库,都是使用迭代器模式的典范,它们为客户端提供了简单易用的 API,并且能够根据业务需要实现自己的迭代器,具备很好的可扩展性。

场景上下文

db 模块用来存储服务注册和监控信息,它的主要接口如下:

type Db interface {
    CreateTable(t *Table) error
    CreateTableIfNotExist(t *Table) error
    DeleteTable(tableName string) error

    Query(tableName string, primaryKey interface{}, result interface{}) error
    Insert(tableName string, primaryKey interface{}, record interface{}) error
    Update(tableName string, primaryKey interface{}, record interface{}) error
    Delete(tableName string, primaryKey interface{}) error
    
    ...
} 

从增删查改接口可以看出,它是一个 key-value 数据库,另外,为了提供类似关系型数据库的按列查询能力,我们又抽象出 Table 对象:

package db

// demo/db/table.go

import (
	"errors"
	//"fmt"
	"reflect"
	//"strings"
)

// Table 数据表定义
type Table struct {
	name       string
	recordType reflect.Type
	records    map[interface{}]record
}

func NewTable(name string) *Table {
	return &Table{
		name:    name,
		records: make(map[interface{}]record),
	}
}

func (t *Table) WithType(recordType reflect.Type) *Table {
	t.recordType = recordType
	return t
}
func (t *Table) Insert(key interface{}, value interface{}) error {

	if _, ok := t.records[key]; ok {
		return errors.New("ErrPrimaryKeyConflict")
	}
	record, err := recordFrom(key, value)
	if err != nil {
		return err
	}
	t.records[key] = record
	return nil

}

package db

import (
	"errors"
	"fmt"
	"reflect"
	"strings"
)

// 因为数据库的每个表都存储着不同对象
// 所以需要把类型存进去,根据类型创建自己需要的对象,再根据对象的属性,创建出表的每一列的属性
// 其中,Table 底层用 map 存储对象数据,但并没有存储对象本身,而是从对象转换而成的 record
type record struct {
	primaryKey interface{}
	fields     map[string]int
	values     []interface{}
}

//从对象转化为 record
func recordFrom(key interface{}, value interface{}) (r record, e error) {
	defer func() {
		if err := recover(); err != nil {
			r = record{}
			e = errors.New("ErrRecordTypeInvalid")
		}
	}()

	vType := reflect.TypeOf(value)
	fmt.Println("vType:", vType)
	vVal := reflect.ValueOf(value)
	fmt.Println("vVal:", vVal)

	if vVal.Type().Kind() == reflect.Ptr {
		vType = vType.Elem()
		vVal = vVal.Elem()
	}

	record := record{
		primaryKey: key,
		fields:     make(map[string]int, vVal.NumField()),
		values:     make([]interface{}, vVal.NumField()),
	}
	fmt.Println("vVal.NumField()", vVal.NumField())

	for i := 0; i < vVal.NumField(); i++ {
		fieldType := vType.Field(i)
		fieldVal := vVal.Field(i)
		name := strings.ToLower(fieldType.Name)
		record.fields[name] = i
		record.values[i] = fieldVal.Interface()
	}

	return record, nil

}

package db

import (
	"fmt"
	"reflect"
	"testing"
)

type testRegion struct {
	Id   int
	Name string
}

func TestTable(t *testing.T) {
	tableName := "testRegion"
	table := NewTable(tableName).WithType((reflect.TypeOf(new(testRegion))))
	table.Insert(2, &testRegion{Id: 2, Name: "beijing"})
	fmt.Println(table.records)
}

用迭代器实现

package db

// demo/db/table.go

import (
	"errors"
	//"fmt"
	"math/rand"
	"reflect"
	"time"
	//"strings"
)

// Table 数据表定义
type Table struct {
	name       string
	recordType reflect.Type
	records    map[interface{}]record
	// 关键点: 持有迭代器工厂方法接口
	iteratorFactory TableIteratorFactory
}

func NewTable(name string) *Table {
	return &Table{
		name:    name,
		records: make(map[interface{}]record),
	}
}

func (t *Table) WithType(recordType reflect.Type) *Table {
	t.recordType = recordType
	return t
}
func (t *Table) Insert(key interface{}, value interface{}) error {

	if _, ok := t.records[key]; ok {
		return errors.New("ErrPrimaryKeyConflict")
	}
	record, err := recordFrom(key, value)
	if err != nil {
		return err
	}
	t.records[key] = record
	return nil

}

// 关键点: 定义Setter方法,提供迭代器工厂的依赖注入
func (t *Table) WithTableIteratorFactory(iteratorFactory TableIteratorFactory) *Table {
	t.iteratorFactory = iteratorFactory
	return t
}

// 关键点: 定义创建迭代器的接口,其中调用迭代器工厂完成实例化
func (t *Table) Iterator() TableIterator {
	return t.iteratorFactory.Create(t)
}

type Next func(interface{}) error
type HasNext func() bool

func (t *Table) ClosureIterator() (HasNext, Next) {
	var records []record
	for _, r := range t.records {
		records = append(records, r)
	}
	rand.Seed(time.Now().UnixNano())
	rand.Shuffle(len(records), func(i, j int) {
		records[i], records[j] = records[j], records[i]
	})
	size := len(records)
	cursor := 0
	hasNext := func() bool {
		return cursor < size
	}
	next := func(next interface{}) error {
		record := records[cursor]
		cursor++
		if err := record.convertByValue(next); err != nil {
			return err
		}
		return nil
	}
	return hasNext, next
}

package db

// demo/db/iterator.go

import (
	"math/rand"
	"sort"
	"time"
)

type TableIterator interface {
	HasNext() bool
	Next(next interface{}) error
}

// 关键点: 定义迭代器接口的实现
// tableIteratorImpl 迭代器接口公共实现类 用来实现遍历表
type tableIteratorImpl struct {
	// 关键点3: 定义一个集合存储待遍历的记录,这里的记录已经排序好或者随机打散
	records []record
	// 关键点4: 定义一个cursor游标记录当前遍历的位置
	cursor int
}

// 关键点5: 在HasNext函数中的判断是否已经遍历完所有记录
func (r *tableIteratorImpl) HasNext() bool {
	return r.cursor < len(r.records)
}

// 关键点: 在Next函数中取出下一个记录,并转换成客户端期望的对象类型,记得增加cursor
func (r *tableIteratorImpl) Next(next interface{}) error {
	record := r.records[r.cursor]
	r.cursor++
	if err := record.convertByValue(next); err != nil {
		return err
	}
	return nil
}

type TableIteratorFactory interface {
	Create(table *Table) TableIterator
}

//创建迭代器的方式用工厂方法模式
//工厂可以创建出两种具体的迭代器 randomTableIteratorFactory sortedTableIteratorFactory
type randomTableIteratorFactory struct {
}

func NewRandomTableIteratorFactory() *randomTableIteratorFactory {
	return &randomTableIteratorFactory{}
}
func (r *randomTableIteratorFactory) Create(table *Table) TableIterator {
	var records []record
	for _, r := range table.records {
		records = append(records, r)
	}
	rand.Seed(time.Now().UnixNano())
	rand.Shuffle(len(records), func(i, j int) {
		records[i], records[j] = records[j], records[i]
	})
	return &tableIteratorImpl{
		records: records,
		cursor:  0,
	}
}

type sortedTableIteratorFactory struct {
	Comparator Comparator
	//comparator Comparator
}

func NewSortedTableIteratorFactory(c Comparator) *sortedTableIteratorFactory {
	return &sortedTableIteratorFactory{
		Comparator: c,
	}

}

func (s *sortedTableIteratorFactory) Create(table *Table) TableIterator {
	var res []record
	for _, r := range table.records {
		res = append(res, r)
	}
	/* re := &records{
		rs:         res,
		comparator: s.Comparator,
	} */
	sort.Sort(newrecords(res, s.Comparator))
	return &tableIteratorImpl{
		records: res,
		cursor:  0,
	}
}

type records struct {
	rs         []record
	comparator Comparator
}

func newrecords(res []record, com Comparator) *records {
	return &records{
		rs:         res,
		comparator: com,
	}

}

type Comparator func(i interface{}, j interface{}) bool

//Len()
func (r *records) Len() int {
	return len(r.rs)
}

//Less(): 成绩将有低到高排序
func (r *records) Less(i, j int) bool {
	return r.comparator(r.rs[i].primaryKey, r.rs[j].primaryKey)
}

//Swap()
func (r *records) Swap(i, j int) {
	tmp := r.rs[i]
	r.rs[i] = r.rs[j]
	r.rs[j] = tmp
}

// demo/db/iterator_test.go
package db

import (
	"fmt"
	"reflect"
	"testing"
)

func regionIdLess(i, j interface{}) bool {
	id1, ok := i.(int)
	if !ok {
		return false
	}
	id2, ok := j.(int)
	if !ok {
		return false
	}
	return id1 < id2
}

func TestRandomTableIterator(t *testing.T) {
	table := NewTable("testRegion").WithType(reflect.TypeOf(new(testRegion))).
		WithTableIteratorFactory(NewSortedTableIteratorFactory(regionIdLess))
	table.Insert(1, &testRegion{Id: 1, Name: "beijing"})
	table.Insert(2, &testRegion{Id: 2, Name: "shanghai"})
	table.Insert(3, &testRegion{Id: 3, Name: "guangdong"})
	//根据table得到 一个存储 排好顺序切片 和指向index的 结构体
	hasNext, next := table.ClosureIterator()

	for i := 0; i < 3; i++ {
		if !hasNext() {
			t.Error("records size error")
		}
		region := new(testRegion)
		if err := next(region); err != nil {
			t.Error(err)
		}
		fmt.Printf("%+v\n", region)
	}
	if hasNext() {
		t.Error("should not has next")
	}

}

func TestSortTableIterator(t *testing.T) {
	table := NewTable("testRegion").WithType(reflect.TypeOf(new(testRegion))).
		WithTableIteratorFactory(NewSortedTableIteratorFactory(regionIdLess))
	table.Insert(3, &testRegion{Id: 3, Name: "beijing"})
	table.Insert(1, &testRegion{Id: 1, Name: "shanghai"})
	table.Insert(2, &testRegion{Id: 2, Name: "guangdong"})
	iter := table.Iterator()
	region1 := new(testRegion)
	iter.Next(region1)
	if region1.Id != 1 {
		t.Error("region1 sort failed")
	}
	region2 := new(testRegion)
	iter.Next(region2)
	if region2.Id != 2 {
		t.Error("region2 sort failed")
	}
	region3 := new(testRegion)
	iter.Next(region3)
	if region3.Id != 3 {
		t.Error("region3 sort failed")
	}
}


// 需要用到的包
// demo/db/record.go
package db

import (
	"errors"
	//"fmt"
	"reflect"
	"strings"
)

// 因为数据库的每个表都存储着不同对象
// 所以需要把类型存进去,根据类型创建自己需要的对象,再根据对象的属性,创建出表的每一列的属性
// 其中,Table 底层用 map 存储对象数据,但并没有存储对象本身,而是从对象转换而成的 record
type record struct {
	primaryKey interface{}
	fields     map[string]int
	values     []interface{}
}

//从对象转化为 record
func recordFrom(key interface{}, value interface{}) (r record, e error) {
	defer func() {
		if err := recover(); err != nil {
			r = record{}
			e = errors.New("ErrRecordTypeInvalid")
		}
	}()

	vType := reflect.TypeOf(value)
	//fmt.Println("vType:", vType)
	vVal := reflect.ValueOf(value)
	//fmt.Println("vVal:", vVal)

	if vVal.Type().Kind() == reflect.Ptr {
		//fmt.Println("is ptr")
		vType = vType.Elem()
		//fmt.Println("vType:", vType)

		vVal = vVal.Elem()
		//fmt.Println("vVal:", vVal)

	}

	record := record{
		primaryKey: key,
		fields:     make(map[string]int, vVal.NumField()),
		values:     make([]interface{}, vVal.NumField()),
	}
	//fmt.Println("vVal.NumField()", vVal.NumField())

	for i := 0; i < vVal.NumField(); i++ {
		fieldType := vType.Field(i)
		//fmt.Println("fieldType :", fieldType)
		fieldVal := vVal.Field(i)
		//fmt.Println("fieldVal:", fieldVal)
		name := strings.ToLower(fieldType.Name)
		record.fields[name] = i
		record.values[i] = fieldVal.Interface()
	}

	return record, nil

}

func (r record) convertByValue(result interface{}) (e error) {
	defer func() {
		if err := recover(); err != nil {
			e = errors.New("ErrRecordTypeInvalid")
		}
	}()
	rType := reflect.TypeOf(result)
	rVal := reflect.ValueOf(result)
	if rType.Kind() == reflect.Ptr {
		rType = rType.Elem()
		rVal = rVal.Elem()
	}
	for i := 0; i < rType.NumField(); i++ {
		field := rVal.Field(i)
		field.Set(reflect.ValueOf(r.values[i]))
	}
	return nil
}

Tags: 设计模式
Share: X (Twitter) Facebook LinkedIn