diff --git a/server/search.go b/server/search.go index 99a8fcc..c11ebcf 100644 --- a/server/search.go +++ b/server/search.go @@ -15,6 +15,7 @@ import ( "golang.org/x/text/language" "log" "reflect" + "sort" "strings" ) @@ -25,9 +26,79 @@ type record struct { ClaimId string `json:"claim_id"` } +type compareFunc func(r1, r2 *record, invert bool) int + +type multiSorter struct { + records []record + compare []compareFunc + invert []bool +} + +var compareFuncs = map[string]compareFunc { + "height": func(r1, r2 *record, invert bool) int { + var res = 0 + if r1.Height < r2.Height { + res = -1 + } else if r1.Height > r2.Height { + res = 1 + } + if invert { + res = res * -1 + } + return res + }, +} + +// Sort sorts the argument slice according to the less functions passed to OrderedBy. +func (ms *multiSorter) Sort(records []record) { + ms.records = records + sort.Sort(ms) +} + +// OrderedBy returns a Sorter that sorts using the less functions, in order. +// Call its Sort method to sort the data. +func OrderedBy(compare ...compareFunc) *multiSorter { + return &multiSorter{ + compare: compare, + } +} + +// Len is part of sort.Interface. +func (ms *multiSorter) Len() int { + return len(ms.records) +} + +// Swap is part of sort.Interface. +func (ms *multiSorter) Swap(i, j int) { + ms.records[i], ms.records[j] = ms.records[j], ms.records[i] +} + +// Less is part of sort.Interface. It is implemented by looping along the +// less functions until it finds a comparison that discriminates between +// the two items (one is less than the other). Note that it can call the +// less functions twice per call. We could change the functions to return +// -1, 0, 1 and reduce the number of calls for greater efficiency: an +// exercise for the reader. +func (ms *multiSorter) Less(i, j int) bool { + p, q := &ms.records[i], &ms.records[j] + // Try all but the last comparison. + var k int + for k = 0; k < len(ms.compare)-1; k++ { + cmp := ms.compare[k] + res := cmp(p, q, ms.invert[k]) + + if res != 0 { + return res > 0 + } + } + // All comparisons to here said "equal", so just return whatever + // the final comparison reports. + return ms.compare[k](p, q, ms.invert[k]) > 0 +} + type orderField struct { Field string - is_asc bool + IsAsc bool } const ( errorResolution = iota @@ -263,80 +334,6 @@ func (s *Server) resolveUrl(ctx context.Context, rawUrl string) *urlResolution { } } } -/* - async def resolve_url(self, raw_url): - if raw_url not in self.resolution_cache: - self.resolution_cache[raw_url] = await self._resolve_url(raw_url) - return self.resolution_cache[raw_url] - - async def _resolve_url(self, raw_url): - try: - url = URL.parse(raw_url) - except ValueError as e: - return e - - stream = LookupError(f'Could not find claim at "{raw_url}".') - - channel_id = await self.resolve_channel_id(url) - if isinstance(channel_id, LookupError): - return channel_id - stream = (await self.resolve_stream(url, channel_id if isinstance(channel_id, str) else None)) or stream - if url.has_stream: - return StreamResolution(stream) - else: - return ChannelResolution(channel_id) - - async def resolve_channel_id(self, url: URL): - if not url.has_channel: - return - if url.channel.is_fullid: - return url.channel.claim_id - if url.channel.is_shortid: - channel_id = await self.full_id_from_short_id(url.channel.name, url.channel.claim_id) - if not channel_id: - return LookupError(f'Could not find channel in "{url}".') - return channel_id - - query = url.channel.to_dict() - if set(query) == {'name'}: - query['is_controlling'] = True - else: - query['order_by'] = ['^creation_height'] - matches, _, _ = await self.search(**query, limit=1) - if matches: - channel_id = matches[0]['claim_id'] - else: - return LookupError(f'Could not find channel in "{url}".') - return channel_id - - async def resolve_stream(self, url: URL, channel_id: str = None): - if not url.has_stream: - return None - if url.has_channel and channel_id is None: - return None - query = url.stream.to_dict() - if url.stream.claim_id is not None: - if url.stream.is_fullid: - claim_id = url.stream.claim_id - else: - claim_id = await self.full_id_from_short_id(query['name'], query['claim_id'], channel_id) - return claim_id - - if channel_id is not None: - if set(query) == {'name'}: - # temporarily emulate is_controlling for claims in channel - query['order_by'] = ['effective_amount', '^height'] - else: - query['order_by'] = ['^channel_join'] - query['channel_id'] = channel_id - query['signature_valid'] = True - elif set(query) == {'name'}: - query['is_controlling'] = True - matches, _, _ = await self.search(**query, limit=1) - if matches: - return matches[0]['claim_id'] - - */ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, error) { var client *elastic.Client = nil @@ -352,9 +349,6 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, client = s.EsClient } - //res := s.resolveUrl(ctx, "@abc#111") - //log.Println(res) - claimTypes := map[string]int { "stream": 1, "channel": 2, @@ -402,6 +396,7 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, var from = 0 var size = 10 var orderBy []orderField + var ms *multiSorter // Ping the Elasticsearch server to get e.g. the version number //_, code, err := client.Ping("http://127.0.0.1:9200").Do(ctx) @@ -445,9 +440,9 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, if len(in.OrderBy) > 0 { for _, x := range in.OrderBy { var toAppend string - var is_asc = false + var isAsc = false if x[0] == '^' { - is_asc = true + isAsc = true x = x[1:] } if _, ok := replacements[x]; ok { @@ -459,7 +454,16 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, if _, ok := textFields[toAppend]; ok { toAppend = toAppend + ".keyword" } - orderBy = append(orderBy, orderField{toAppend, is_asc}) + orderBy = append(orderBy, orderField{toAppend, isAsc}) + } + + ms = &multiSorter{ + invert: make([]bool, len(orderBy)), + compare: make([]compareFunc, len(orderBy)), + } + for i, x := range orderBy { + ms.compare[i] = compareFuncs[x.Field] + ms.invert[i] = x.IsAsc } } @@ -471,7 +475,6 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, q = q.Must(elastic.NewTermsQuery("claim_type", searchVals...)) } - // FIXME is this a text field or not? if len(in.StreamType) > 0 { searchVals := make([]interface{}, len(in.StreamType)) for i := 0; i < len(in.StreamType); i++ { @@ -538,7 +541,14 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, var collapse *elastic.CollapseBuilder if in.LimitClaimsPerChannel != nil { println(in.LimitClaimsPerChannel.Value) - innerHit := elastic.NewInnerHit().Size(int(in.LimitClaimsPerChannel.Value)).Name("channel_id.keyword") + innerHit := elastic. + NewInnerHit(). + //From(0). + Size(int(in.LimitClaimsPerChannel.Value)). + Name("channel_id") + for _, x := range orderBy { + innerHit = innerHit.Sort(x.Field, x.IsAsc) + } collapse = elastic.NewCollapseBuilder("channel_id.keyword").InnerHit(innerHit) } @@ -571,9 +581,7 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, q = AddInvertibleField(in.ChannelId, "channel_id.keyword", q) q = AddInvertibleField(in.ChannelIds, "channel_id.keyword", q) - /* - */ q = AddRangeField(in.TxPosition, "tx_position", q) q = AddRangeField(in.Amount, "amount", q) @@ -629,6 +637,7 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, } fsc := elastic.NewFetchSourceContext(true).Exclude("description", "title") + log.Printf("from: %d, size: %d\n", from, size) search := client.Search(). Index(searchIndices...). FetchSourceContext(fsc). @@ -638,8 +647,8 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, search = search.Collapse(collapse) } for _, x := range orderBy { - log.Println(x.Field, x.is_asc) - search = search.Sort(x.Field, x.is_asc) + log.Println(x.Field, x.IsAsc) + search = search.Sort(x.Field, x.IsAsc) } searchResult, err := search.Do(ctx) // execute @@ -649,40 +658,75 @@ func (s *Server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.Outputs, log.Printf("%s: found %d results in %dms\n", in.Text, len(searchResult.Hits.Hits), searchResult.TookInMillis) - txos := make([]*pb.Output, len(searchResult.Hits.Hits)) + var txos []*pb.Output - var r record - for i, item := range searchResult.Each(reflect.TypeOf(r)) { - if t, ok := item.(record); ok { - txos[i] = &pb.Output{ + if in.LimitClaimsPerChannel == nil { + txos = make([]*pb.Output, len(searchResult.Hits.Hits)) + + var r record + for i, item := range searchResult.Each(reflect.TypeOf(r)) { + if t, ok := item.(record); ok { + txos[i] = &pb.Output{ + TxHash: util.ToHash(t.Txid), + Nout: t.Nout, + Height: t.Height, + } + } + } + } else { + records := make([]record, 0, len(searchResult.Hits.Hits) * int(in.LimitClaimsPerChannel.Value)) + txos = make([]*pb.Output, 0, len(searchResult.Hits.Hits) * int(in.LimitClaimsPerChannel.Value)) + var i = 0 + for _, hit := range searchResult.Hits.Hits { + if innerHit, ok := hit.InnerHits["channel_id"]; ok { + for _, hitt := range innerHit.Hits.Hits { + if i >= size { + break + } + var t record + err := json.Unmarshal(hitt.Source, &t) + if err != nil { + return nil, err + } + records = append(records, t) + i++ + } + } + } + ms.Sort(records) + log.Println(records) + for _, t := range records { + res := &pb.Output{ TxHash: util.ToHash(t.Txid), Nout: t.Nout, Height: t.Height, } + txos = append(txos, res) } } // or if you want more control - for _, hit := range searchResult.Hits.Hits { - // hit.Index contains the name of the index - - var t map[string]interface{} // or could be a Record - err := json.Unmarshal(hit.Source, &t) - if err != nil { - return nil, err - } - - b, err := json.MarshalIndent(t, "", " ") - if err != nil { - fmt.Println("error:", err) - } - fmt.Println(string(b)) - //for k := range t { - // fmt.Println(k) - //} - //return nil, nil - } + //for _, hit := range searchResult.Hits.Hits { + // // hit.Index contains the name of the index + // + // var t map[string]interface{} // or could be a Record + // err := json.Unmarshal(hit.Source, &t) + // if err != nil { + // return nil, err + // } + // + // b, err := json.MarshalIndent(t, "", " ") + // if err != nil { + // fmt.Println("error:", err) + // } + // fmt.Println(string(b)) + // //for k := range t { + // // fmt.Println(k) + // //} + // //return nil, nil + //} + log.Printf("totalhits: %d\n", searchResult.TotalHits()) return &pb.Outputs{ Txos: txos, Total: uint32(searchResult.TotalHits()),