/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.BaseQueryFactory;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.RescoreKNNVectorQuery;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

/**
 * Static factory that turns a {@link BaseQueryFactory.CreateQueryRequest} into a Lucene {@link Query}.
 *
 * <p>Two construction paths exist. When the request's engine is listed in
 * {@link KNNEngine#getEnginesThatCreateCustomSegmentFiles()} a {@link KNNQuery} is built (and
 * optionally wrapped in a {@link NativeEngineKnnVectorQuery}); otherwise a Lucene vector query is
 * built, wrapped in a {@link LuceneEngineKnnVectorQuery}, and optionally in a
 * {@link RescoreKNNVectorQuery}.
 */
public class KNNQueryFactory
extends BaseQueryFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryFactory.class);

    /**
     * Builds the k-NN query described by the given request.
     *
     * @param createQueryRequest request carrying the field, query vector(s), k, engine, optional
     *                           filter, optional rescore context and optional shard context
     * @return the query for the custom-segment-file engine path or for the Lucene path
     * @throws IllegalArgumentException if nested expansion is requested but the shard context
     *                                  supplies no parent filter
     */
    public static Query create(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
        final String indexName = createQueryRequest.getIndexName();
        final String fieldName = createQueryRequest.getFieldName();
        final int k = createQueryRequest.getK();
        final float[] vector = createQueryRequest.getVector();
        final float[] originalVector = createQueryRequest.getOriginalVector();
        final byte[] byteVector = createQueryRequest.getByteVector();
        final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
        // Resolved before any validation so that failures surface in the same order as before.
        final Query filterQuery = getFilterQuery(createQueryRequest);
        final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
        final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
        final boolean expandNested = createQueryRequest.isExpandNested();
        final boolean memoryOptimizedSearchEnabled = createQueryRequest.isMemoryOptimizedSearchEnabled();

        // Without a shard context there is no parent filter and the shard id falls back to -1.
        final QueryShardContext shardContext = createQueryRequest.getContext().orElse(null);
        final BitSetProducer parentFilter = shardContext == null ? null : shardContext.getParentFilter();
        final int shardId = shardContext == null ? -1 : shardContext.getShardId();

        if (expandNested && parentFilter == null) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "Invalid value provided for the [%s] field. [%s] is only supported with a nested field.",
                    "expand_nested_docs",
                    "expand_nested_docs"
                )
            );
        }

        final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
            // The filter is dropped (null) when the engine is not in the filter-capable set.
            final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, knnEngine);
            log.debug(
                "Creating custom k-NN query for index:{}, field:{}, k:{}, filterQuery:{}, efSearch:{}",
                indexName,
                fieldName,
                k,
                validatedFilterQuery,
                methodParameters
            );
            // BINARY queries carry only the byte vector; every other type also carries the float vectors.
            final KNNQuery knnQuery = switch (vectorDataType) {
                case BINARY -> KNNQuery.builder()
                    .field(fieldName)
                    .byteQueryVector(byteVector)
                    .indexName(indexName)
                    .parentsFilter(parentFilter)
                    .k(k)
                    .methodParameters(methodParameters)
                    .filterQuery(validatedFilterQuery)
                    .vectorDataType(vectorDataType)
                    .rescoreContext(rescoreContext)
                    .shardId(shardId)
                    .isMemoryOptimizedSearch(memoryOptimizedSearchEnabled)
                    .build();
                default -> KNNQuery.builder()
                    .field(fieldName)
                    .queryVector(vector)
                    .originalQueryVector(originalVector)
                    .byteQueryVector(byteVector)
                    .indexName(indexName)
                    .parentsFilter(parentFilter)
                    .k(k)
                    .methodParameters(methodParameters)
                    .filterQuery(validatedFilterQuery)
                    .vectorDataType(vectorDataType)
                    .rescoreContext(rescoreContext)
                    .shardId(shardId)
                    .isMemoryOptimizedSearch(memoryOptimizedSearchEnabled)
                    .build();
            };
            final boolean wrapInNativeEngineQuery = memoryOptimizedSearchEnabled
                || createQueryRequest.getRescoreContext().isPresent()
                || (KNNEngine.ENGINES_SUPPORTING_NESTED_FIELDS.contains(knnEngine) && expandNested);
            if (!wrapInNativeEngineQuery) {
                return knnQuery;
            }
            return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.getInstance(), expandNested);
        }

        final Integer requestEfSearch = findRequestedEfSearch(methodParameters);
        final boolean needsRescore = shouldRescore(rescoreContext);
        final int overSampledK = needsRescore
            ? rescoreContext.getFirstPassK(k, false, getDimension(vector, byteVector))
            : k;
        // The Lucene query is asked for the larger of the (possibly over-sampled) k and ef_search.
        final int luceneK;
        if (requestEfSearch == null) {
            luceneK = overSampledK;
        } else {
            luceneK = Math.max(overSampledK, requestEfSearch);
        }
        log.debug("Creating Lucene k-NN query for index: {}, field:{}, k: {}", indexName, fieldName, luceneK);

        final Query vectorQuery = getKnnVectorQuery(
            fieldName,
            vector,
            byteVector,
            luceneK,
            filterQuery,
            parentFilter,
            expandNested,
            vectorDataType
        );
        final LuceneEngineKnnVectorQuery luceneKnnQuery = new LuceneEngineKnnVectorQuery(vectorQuery);
        if (!needsRescore) {
            return luceneKnnQuery;
        }
        // Rescoring is done against the caller's original k, not the enlarged luceneK.
        return new RescoreKNNVectorQuery(luceneKnnQuery, fieldName, k, vector, shardId);
    }

    /**
     * Looks up the {@code ef_search} entry of the method parameters.
     *
     * @param methodParameters request method parameters, may be {@code null}
     * @return the stored value cast to {@link Integer}, or {@code null} when the map is
     *         {@code null} or has no {@code ef_search} key
     */
    private static Integer findRequestedEfSearch(Map<String, ?> methodParameters) {
        if (methodParameters == null || !methodParameters.containsKey("ef_search")) {
            return null;
        }
        return (Integer) methodParameters.get("ef_search");
    }

    /**
     * Returns the length of whichever query vector is present, preferring the float vector.
     *
     * @throws IllegalStateException if both arrays are {@code null}
     */
    private static int getDimension(float[] floatQueryVector, byte[] byteQueryVector) {
        if (floatQueryVector == null && byteQueryVector == null) {
            throw new IllegalStateException("QueryVector has neither float nor byte array");
        }
        return floatQueryVector != null ? floatQueryVector.length : byteQueryVector.length;
    }

    /**
     * Returns the filter query only when it is non-null and the engine is contained in
     * {@link KNNEngine#getEnginesThatSupportsFilters()}; otherwise returns {@code null}.
     */
    private static Query validateFilterQuerySupport(Query filterQuery, KNNEngine knnEngine) {
        log.debug("filter query {}, knnEngine {}", filterQuery, knnEngine);
        if (filterQuery == null) {
            return null;
        }
        return KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine) ? filterQuery : null;
    }

    /**
     * Tells whether a rescore context is present and reports rescoring as enabled.
     */
    private static boolean shouldRescore(RescoreContext rescoreContext) {
        if (rescoreContext == null) {
            return false;
        }
        return rescoreContext.isRescoreEnabled();
    }

    /**
     * Creates the underlying vector query for the Lucene path.
     *
     * <p>{@link VectorDataType#FLOAT} uses the float query vector; every other data type uses the
     * byte query vector. When a parent filter is supplied the query is created through
     * {@link NestedKnnVectorQueryFactory}; otherwise a plain Lucene vector query is returned.
     *
     * @throws NullPointerException if {@code vectorDataType} is {@code null}
     */
    private static Query getKnnVectorQuery(String fieldName, float[] floatQueryVector, byte[] byteQueryVector, int k, Query filterQuery, BitSetProducer parentFilter, boolean expandNested, @NonNull VectorDataType vectorDataType) {
        if (vectorDataType == null) {
            throw new NullPointerException("vectorDataType is marked non-null but is null");
        }
        final boolean useFloatVector = vectorDataType == VectorDataType.FLOAT;
        if (parentFilter != null) {
            if (useFloatVector) {
                return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatQueryVector, k, filterQuery, parentFilter, expandNested);
            }
            return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteQueryVector, k, filterQuery, parentFilter, expandNested);
        }
        assert (!expandNested) : "expandNested is allowed to be true only for nested fields.";
        if (useFloatVector) {
            return new KnnFloatVectorQuery(fieldName, floatQueryVector, k, filterQuery);
        }
        return new KnnByteVectorQuery(fieldName, byteQueryVector, k, filterQuery);
    }
}

