Make the distance comparator simpler to configure

This commit is contained in:
Dmytro Meleshko 2021-11-15 20:20:06 +02:00
parent 45adb7b7b3
commit 4ab258c3cc
4 changed files with 121 additions and 127 deletions

View file

@ -67,23 +67,17 @@ end
This source also provides a comparator function which uses information from the word indexer
to sort completion results based on the distance of the word from the cursor line. It will also
sort completion results coming from other sources, such as Language Servers, which might improve
accuracy of their suggestions too. The usage is, unfortunately, pretty hacky:
accuracy of their suggestions too. The usage is as follows:
```lua
local cmp = require('cmp')
local cmp_buffer = require('cmp_buffer').new()
-- You have to register and use a source with a different name from 'buffer'
-- because otherwise duplicate buffer sources will both be indexing the same
-- file twice, in parallel.
cmp.register_source('buffer_with_distance', cmp_buffer)
local cmp_buffer = require('cmp_buffer')
cmp.setup({
sources = {
{ name = 'buffer_with_distance' }, -- NOT 'buffer'!
{ name = 'buffer' },
-- The rest of your sources...
},
sorting = {
comparators = {
function(...) return cmp_buffer:compare_word_distance(...) end,

View file

@ -1,2 +1 @@
require'cmp'.register_source('buffer', require'cmp_buffer'.new())
require('cmp').register_source('buffer', require('cmp_buffer'))

View file

@ -1,116 +1 @@
local buffer = require('cmp_buffer.buffer')
---@class cmp_buffer.Options
---@field public keyword_length number
---@field public keyword_pattern string
---@field public get_bufnrs fun(): number[]
---@type cmp_buffer.Options
local defaults = {
keyword_length = 3,
keyword_pattern = [[\%(-\?\d\+\%(\.\d\+\)\?\|\h\w*\%([\-]\w*\)*\)]],
get_bufnrs = function()
return { vim.api.nvim_get_current_buf() }
end,
}
local source = {}
source.new = function()
local self = setmetatable({}, { __index = source })
self.buffers = {}
return self
end
---@return cmp_buffer.Options
source._validate_options = function(_, params)
local opts = vim.tbl_deep_extend('keep', params.option, defaults)
vim.validate({
keyword_length = { opts.keyword_length, 'number' },
keyword_pattern = { opts.keyword_pattern, 'string' },
get_bufnrs = { opts.get_bufnrs, 'function' },
})
return opts
end
source.get_keyword_pattern = function(self, params)
local opts = self:_validate_options(params)
return opts.keyword_pattern
end
source.complete = function(self, params, callback)
local opts = self:_validate_options(params)
local processing = false
local bufs = self:_get_buffers(opts)
for _, buf in ipairs(bufs) do
if buf.timer then
processing = true
break
end
end
vim.defer_fn(function()
local input = string.sub(params.context.cursor_before_line, params.offset)
local items = {}
local words = {}
for _, buf in ipairs(bufs) do
for _, word_list in ipairs(buf:get_words()) do
for word, _ in pairs(word_list) do
if not words[word] and input ~= word then
words[word] = true
table.insert(items, {
label = word,
dup = 0,
})
end
end
end
end
callback({
items = items,
isIncomplete = processing,
})
end, processing and 100 or 0)
end
---@param opts cmp_buffer.Options
source._get_buffers = function(self, opts)
local buffers = {}
for _, bufnr in ipairs(opts.get_bufnrs()) do
if not self.buffers[bufnr] then
local new_buf = buffer.new(bufnr, opts)
new_buf.on_close_cb = function()
self.buffers[bufnr] = nil
end
new_buf:index()
new_buf:watch()
self.buffers[bufnr] = new_buf
end
table.insert(buffers, self.buffers[bufnr])
end
return buffers
end
source._get_distance_from_entry = function(self, entry)
local buf = self.buffers[entry.context.bufnr]
if buf then
local distances = buf:get_words_distances(entry.context.cursor.line + 1)
return distances[entry.completion_item.filterText] or distances[entry.completion_item.label]
end
end
source.compare_word_distance = function(self, entry1, entry2)
if entry1.context ~= entry2.context then
return
end
local dist1 = self:_get_distance_from_entry(entry1) or math.huge
local dist2 = self:_get_distance_from_entry(entry2) or math.huge
if dist1 ~= dist2 then
return dist1 < dist2
end
end
return source
return require('cmp_buffer.source').new()

116
lua/cmp_buffer/source.lua Normal file
View file

@ -0,0 +1,116 @@
local buffer = require('cmp_buffer.buffer')
---@class cmp_buffer.Options
---@field public keyword_length number
---@field public keyword_pattern string
---@field public get_bufnrs fun(): number[]
---@type cmp_buffer.Options
local defaults = {
keyword_length = 3,
keyword_pattern = [[\%(-\?\d\+\%(\.\d\+\)\?\|\h\w*\%([\-]\w*\)*\)]],
get_bufnrs = function()
return { vim.api.nvim_get_current_buf() }
end,
}
local source = {}
source.new = function()
local self = setmetatable({}, { __index = source })
self.buffers = {}
return self
end
---@return cmp_buffer.Options
source._validate_options = function(_, params)
local opts = vim.tbl_deep_extend('keep', params.option, defaults)
vim.validate({
keyword_length = { opts.keyword_length, 'number' },
keyword_pattern = { opts.keyword_pattern, 'string' },
get_bufnrs = { opts.get_bufnrs, 'function' },
})
return opts
end
source.get_keyword_pattern = function(self, params)
local opts = self:_validate_options(params)
return opts.keyword_pattern
end
source.complete = function(self, params, callback)
local opts = self:_validate_options(params)
local processing = false
local bufs = self:_get_buffers(opts)
for _, buf in ipairs(bufs) do
if buf.timer then
processing = true
break
end
end
vim.defer_fn(function()
local input = string.sub(params.context.cursor_before_line, params.offset)
local items = {}
local words = {}
for _, buf in ipairs(bufs) do
for _, word_list in ipairs(buf:get_words()) do
for word, _ in pairs(word_list) do
if not words[word] and input ~= word then
words[word] = true
table.insert(items, {
label = word,
dup = 0,
})
end
end
end
end
callback({
items = items,
isIncomplete = processing,
})
end, processing and 100 or 0)
end
---@param opts cmp_buffer.Options
source._get_buffers = function(self, opts)
local buffers = {}
for _, bufnr in ipairs(opts.get_bufnrs()) do
if not self.buffers[bufnr] then
local new_buf = buffer.new(bufnr, opts)
new_buf.on_close_cb = function()
self.buffers[bufnr] = nil
end
new_buf:index()
new_buf:watch()
self.buffers[bufnr] = new_buf
end
table.insert(buffers, self.buffers[bufnr])
end
return buffers
end
source._get_distance_from_entry = function(self, entry)
local buf = self.buffers[entry.context.bufnr]
if buf then
local distances = buf:get_words_distances(entry.context.cursor.line + 1)
return distances[entry.completion_item.filterText] or distances[entry.completion_item.label]
end
end
source.compare_word_distance = function(self, entry1, entry2)
if entry1.context ~= entry2.context then
return
end
local dist1 = self:_get_distance_from_entry(entry1) or math.huge
local dist2 = self:_get_distance_from_entry(entry2) or math.huge
if dist1 ~= dist2 then
return dist1 < dist2
end
end
return source