Elasticsearch plugin development: a custom payload_score query

2021-02-16 11:24:21

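When per-term weights need to be stored in the index, they have to be saved as payloads (for example with the delimited_payload token filter).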

Source code: https://github.com/limingnihao/elasticsearch-reference/tree/master/Examples

Official documentation (delimited_payload token filter): https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-delimited-payload-tokenfilter.html

The indexed text then looks like this, with each term's weight after the | delimiter:

the|0 brown|3 fox|4 is|0 quick|10
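The index-time side can be pictured with a small plain-Java model of what the delimited_payload filter does with its defaults (delimiter |, encoding float): each whitespace-separated token is split at the first |, and the weight is stored as a 4-byte payload. This is not Lucene code; the class and method names are hypothetical, and the big-endian byte layout is my understanding of how Lucene encodes float payloads.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical plain-Java model of the delimited_payload token filter
// with its default settings (delimiter '|', encoding "float").
class DelimitedPayloadSketch {

    // One token: the term text plus its payload bytes (null if it had none).
    static final class Token {
        final String term;
        final byte[] payload;

        Token(String term, byte[] payload) {
            this.term = term;
            this.payload = payload;
        }
    }

    // Encode a float as 4 big-endian bytes (ByteBuffer's default byte order).
    static byte[] encodeFloat(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    // Whitespace tokenizer, then split each token at the first '|'.
    // Tokens without a delimiter are kept but carry no payload.
    static List<Token> tokenize(String text) {
        List<Token> tokens = new ArrayList<>();
        for (String raw : text.trim().split("\\s+")) {
            int sep = raw.indexOf('|');
            if (sep < 0) {
                tokens.add(new Token(raw, null));
            } else {
                float weight = Float.parseFloat(raw.substring(sep + 1));
                tokens.add(new Token(raw.substring(0, sep), encodeFloat(weight)));
            }
        }
        return tokens;
    }

    public static void main(String[] args) {
        for (Token t : tokenize("the|0 brown|3 fox|4 is|0 quick|10")) {
            float weight = ByteBuffer.wrap(t.payload).getFloat();
            System.out.println(t.term + " -> " + weight);
        }
    }
}
```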

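At query time, making use of these stored values requires Lucene's PayloadScoreQuery (which scores documents from the payloads) or PayloadCheckQuery (which only matches positions whose payloads equal given values).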

PayloadScoreQuery:

First, look at the constructor of Lucene's PayloadScoreQuery:

  /**
   * Creates a new PayloadScoreQuery
   * @param wrappedQuery the query to wrap
   * @param function a PayloadFunction to use to modify the scores
   * @param decoder a PayloadDecoder to convert payloads into float values
   * @param includeSpanScore include both span score and payload score in the scoring algorithm
   */
  public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, PayloadDecoder decoder, boolean includeSpanScore) {
    this.wrappedQuery = Objects.requireNonNull(wrappedQuery);
    this.function = Objects.requireNonNull(function);
    this.decoder = Objects.requireNonNull(decoder);
    this.includeSpanScore = includeSpanScore;
  }

As you can see, the constructor takes four arguments:

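  • SpanQuery wrappedQuery: the query that matches (retrieves) documents; it must be a SpanQuery.
  • PayloadFunction function: how the payloads are combined into a score when several terms match; Lucene provides max, min, sum and average (MaxPayloadFunction, MinPayloadFunction, SumPayloadFunction, AveragePayloadFunction).
  • PayloadDecoder decoder: how the stored payload bytes are decoded into float values, e.g. from int or float encoding.
  • boolean includeSpanScore: if true, the span query's own score is combined with (multiplied by) the payload score; if false, only the payload score is used.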
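To make the function and includeSpanScore arguments concrete, here is a simplified plain-Java model of how one document's score is derived from the payloads of its matching spans. It is not the Lucene API: the enum and method names are hypothetical, and the behaviour (running min/max/sum, division for average, a default of 1 when no payload was seen, multiplication by the span score) reflects my reading of Lucene's payload functions and PayloadScoreQuery.

```java
import java.util.List;

// Hypothetical model of PayloadScoreQuery scoring; not Lucene code.
class PayloadScoringSketch {

    enum Func { MIN, MAX, SUM, AVERAGE }

    // Fold each matched payload into a running score, then finish per document.
    static float combine(Func func, List<Float> payloads) {
        float score = 0f;
        for (int seen = 0; seen < payloads.size(); seen++) {
            float p = payloads.get(seen);
            switch (func) {
                case MIN: score = seen == 0 ? p : Math.min(score, p); break;
                case MAX: score = seen == 0 ? p : Math.max(score, p); break;
                default:  score += p; // SUM and AVERAGE both accumulate
            }
        }
        // No payloads seen: fall back to a neutral 1.
        if (payloads.isEmpty()) {
            return 1f;
        }
        return func == Func.AVERAGE ? score / payloads.size() : score;
    }

    // includeSpanScore = true multiplies in the wrapped span query's own score.
    static float finalScore(Func func, List<Float> payloads, float spanScore, boolean includeSpanScore) {
        float payloadScore = combine(func, payloads);
        return includeSpanScore ? spanScore * payloadScore : payloadScore;
    }

    public static void main(String[] args) {
        List<Float> hits = List.of(3f, 4f, 10f); // payloads of brown, fox, quick
        for (Func f : Func.values()) {
            System.out.println(f + " -> " + finalScore(f, hits, 2f, false));
        }
    }
}
```

With the example weights brown|3, fox|4 and quick|10 all matching, sum gives 17, max 10, min 3 and average about 5.67; turning on includeSpanScore scales that value by the span score.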

Now for the plugin itself. Two classes are needed: the plugin class and a query builder.

PayloadScoreQParserPlugin

Registers the new query with Elasticsearch and tells it how to construct a PayloadScoreQueryBuilder:

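import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;

import java.util.Collections;
import java.util.List;

public class PayloadScoreQParserPlugin extends Plugin implements SearchPlugin {

    // Register the query under PayloadScoreQueryBuilder.NAME with a stream reader
    // (the StreamInput constructor) and a JSON parser (fromXContent)
    @Override
    public List<QuerySpec<?>> getQueries() {
        return Collections.singletonList(
            new QuerySpec<>(PayloadScoreQueryBuilder.NAME, PayloadScoreQueryBuilder::new, PayloadScoreQueryBuilder::fromXContent)
        );
    }
}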

PayloadScoreQueryBuilder

First, the fromXContent method, which parses the request parameters.

It parses our custom parameters: query, func, calc (reserved for a later extension that cross-combines weights) and includeSpanScore.

public static QueryBuilder fromXContent(XContentParser parser) throws IOException {
    String currentFieldName = null;
    XContentParser.Token token;
    QueryBuilder iqb = null;

    String func = null;
    String calc = null;
    boolean includeSpanScore = false;
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            currentFieldName = parser.currentName();
        } else if (token == XContentParser.Token.START_OBJECT) {
            if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                iqb = parseInnerQueryBuilder(parser);
            } else {
                throw new ParsingException(parser.getTokenLocation(),
                    "[" + NAME + "] query does not support [" + currentFieldName + "]");
            }
        } else if (token.isValue()) {
            if (FUNC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                func = parser.text();
            } else if (CALC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                calc = parser.text();
            } else if (INCLUDE_SPAN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                includeSpanScore = parser.booleanValue();
            } else {
                throw new ParsingException(parser.getTokenLocation(),
                    "[" + NAME + "] query does not support [" + currentFieldName + "]");
            }
        }
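    }
    // The inner query is mandatory: fail at parse time instead of in doToQuery
    if (iqb == null) {
        throw new ParsingException(parser.getTokenLocation(),
            "[" + NAME + "] requires a [query] field");
    }
    return new PayloadScoreQueryBuilder(iqb, func, calc, includeSpanScore);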
}
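The branching in fromXContent (object-valued keys versus scalar keys, unknown keys rejected) can be tried out without Elasticsearch by treating an already-parsed JSON object as a Map. This is a hypothetical sketch: the Map stands in for XContentParser, IllegalArgumentException stands in for ParsingException, and it also rejects a request that lacks a query field so that the error surfaces at parse time. For includeSpanScore the sketch accepts a JSON boolean or the exact strings "true"/"false".

```java
import java.util.Map;

// Hypothetical model of the fromXContent dispatch; a Map replaces XContentParser.
class PayloadScoreParamsSketch {
    static final String NAME = "payload_score";

    final Object query;
    final String func;
    final String calc;
    final boolean includeSpanScore;

    PayloadScoreParamsSketch(Object query, String func, String calc, boolean includeSpanScore) {
        this.query = query;
        this.func = func;
        this.calc = calc;
        this.includeSpanScore = includeSpanScore;
    }

    static PayloadScoreParamsSketch parse(Map<String, Object> body) {
        Object query = null;
        String func = null;
        String calc = null;
        boolean includeSpanScore = false; // same default as the builder
        for (Map.Entry<String, Object> e : body.entrySet()) {
            String key = e.getKey();
            Object value = e.getValue();
            if (value instanceof Map) {
                // Only "query" may hold an object (the inner span query)
                if (!"query".equals(key)) {
                    throw unsupported(key);
                }
                query = value;
            } else if ("func".equals(key)) {
                func = String.valueOf(value);
            } else if ("calc".equals(key)) {
                calc = String.valueOf(value);
            } else if ("includeSpanScore".equals(key)) {
                includeSpanScore = parseStrictBoolean(value);
            } else {
                throw unsupported(key);
            }
        }
        if (query == null) {
            throw new IllegalArgumentException("[" + NAME + "] requires a [query] field");
        }
        return new PayloadScoreParamsSketch(query, func, calc, includeSpanScore);
    }

    static boolean parseStrictBoolean(Object value) {
        if (value instanceof Boolean) return (Boolean) value;
        if ("true".equals(value)) return true;
        if ("false".equals(value)) return false;
        throw new IllegalArgumentException("[" + NAME + "] includeSpanScore is not a boolean: " + value);
    }

    static IllegalArgumentException unsupported(String key) {
        return new IllegalArgumentException("[" + NAME + "] query does not support [" + key + "]");
    }

    public static void main(String[] args) {
        PayloadScoreParamsSketch p = parse(Map.of(
            "func", "sum",
            "includeSpanScore", false,
            "query", Map.of("span_term", Map.of("FIELD", "test"))));
        System.out.println(p.func + " " + p.includeSpanScore);
    }
}
```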

Next, the doToQuery method, which builds the PayloadScoreQuery.

Its main job is to construct the four arguments that Lucene's PayloadScoreQuery needs:

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
    // Build the inner Lucene query; PayloadScoreQuery can only wrap a SpanQuery
    Query innerQuery = query.toQuery(context);
    if (!(innerQuery instanceof SpanQuery)) {
        throw new IllegalArgumentException(
            "[" + NAME + "] inner query must be a span query (e.g. span_term, span_or)");
    }
    SpanQuery spanQuery = (SpanQuery) innerQuery;

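    // Map the "func" parameter (e.g. "sum", "max") to a Lucene PayloadFunction
    PayloadFunction payloadFunction = PayloadUtils.getPayloadFunction(this.func);
    if (payloadFunction == null) {
        throw new IllegalArgumentException("Unknown payload function: " + func);
    }
    // Decode payloads as floats; this must match the encoding used at index time
    PayloadDecoder payloadDecoder = PayloadUtils.getPayloadDecoder("float");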

    return new PayloadScoreQuery(spanQuery, payloadFunction, payloadDecoder, this.includeSpanScore);
}
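doToQuery relies on PayloadUtils.getPayloadDecoder("float"), but PayloadUtils is not shown in this article. The sketch below is a hypothetical plain-Java stand-in for that lookup, returning a function from payload bytes to a float: "float" reads a 4-byte big-endian float, "int" a 4-byte big-endian int, and any other type returns null so the caller can fail fast. It does not handle documents whose tokens carry no payload.

```java
import java.nio.ByteBuffer;
import java.util.function.Function;

// Hypothetical stand-in for a payload-decoder lookup; not Lucene or Elasticsearch code.
class PayloadDecoderSketch {

    static Function<byte[], Float> getPayloadDecoder(String type) {
        if ("float".equals(type)) {
            // 4 big-endian bytes interpreted as an IEEE 754 float
            return bytes -> ByteBuffer.wrap(bytes).getFloat();
        }
        if ("int".equals(type)) {
            // 4 big-endian bytes interpreted as an int, widened to float
            return bytes -> (float) ByteBuffer.wrap(bytes).getInt();
        }
        return null; // unknown type: let the caller raise an error
    }

    public static void main(String[] args) {
        byte[] payload = ByteBuffer.allocate(4).putFloat(10f).array();
        System.out.println(getPayloadDecoder("float").apply(payload));
    }
}
```

The decoder type has to agree with the filter's encoding setting: reading float-encoded bytes as an int yields a meaningless number rather than an error, so a hard-coded "float" decoder assumes the index was built with float encoding (the delimited_payload default).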

Example request:

POST http://127.0.0.1:9200/position/_search
{
    "query": {
        "payload_score": {
            "func": "sum",
            "calc": "sum",
            "includeSpanScore": false,
            "query": {
                "span_or": {
                    "clauses": [
                        {
                            "span_term": {
                                "FIELD": "test"
                            }
                        }
                    ]
                }
            }
        }
    }
}
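The same search can be sent from Java with the JDK 11+ HttpClient. The host, index (position) and placeholder field name (FIELD) come from the request above, with includeSpanScore sent as a JSON boolean; the Content-Type header is added because Elasticsearch requires JSON request bodies to declare it. Actually sending the request needs a running cluster with the plugin installed.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Builds (and in main, sends) the payload_score search request.
class PayloadScoreRequestSketch {

    static final String BODY = "{\"query\":{\"payload_score\":{"
        + "\"func\":\"sum\",\"calc\":\"sum\",\"includeSpanScore\":false,"
        + "\"query\":{\"span_or\":{\"clauses\":[{\"span_term\":{\"FIELD\":\"test\"}}]}}}}}";

    static HttpRequest buildRequest(String baseUrl, String index) {
        return HttpRequest.newBuilder()
            .uri(URI.create(baseUrl + "/" + index + "/_search"))
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(BODY))
            .build();
    }

    public static void main(String[] args) throws Exception {
        HttpRequest request = buildRequest("http://127.0.0.1:9200", "position");
        // Requires a reachable cluster with the payload_score plugin installed
        HttpResponse<String> response = HttpClient.newHttpClient()
            .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```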
