implement highlights for partial nodes

This commit is contained in:
Rom Grk 2020-10-20 19:10:00 -04:00
parent ee02ba1da7
commit 81474aff59
2 changed files with 99 additions and 49 deletions

View file

@ -1,6 +1,8 @@
local vim = vim
local api = vim.api
local ts = vim.treesitter
local Highlighter = ts.highlighter
local ts_utils = require'nvim-treesitter.ts_utils'
local parsers = require'nvim-treesitter.parsers'
@ -8,27 +10,30 @@ local parsers = require'nvim-treesitter.parsers'
local winid = nil
local bufnr = api.nvim_create_buf(false, true)
local label = nil
local ns = api.nvim_create_namespace('nvim-treesitter-context')
local current_node = nil
-- Helper functions
local get_line_for_node = function(node, type_patterns, transform_fn)
local is_valid = function(node, type_patterns)
local node_type = node:type()
local is_valid = false
for _, rgx in ipairs(type_patterns) do
if node_type:find(rgx) then
is_valid = true
break
return true
end
end
return false
end
if not is_valid then return '' end
local get_text_for_node = function(node)
return ts_utils.get_node_text(node)[1]
end
-- local node_text = transform_fn(ts_utils.get_node_text(node)[1] or '')
local line = api.nvim_buf_get_lines(0, node:start(), node:start() + 1, false)[1]
return line
local get_lines_for_node = function(node)
local start_row = node:start()
local end_row = node:end_()
return api.nvim_buf_get_lines(0, start_row, end_row + 1, false)
end
-- Trim spaces and opening brackets from end
@ -36,34 +41,6 @@ local transform_line = function(line)
return line:gsub('%s*[%[%(%{]*%s*$', ''):gsub('\n', '')
end
local get_context = function(opts)
if not parsers.has_parser() then return nil end
local options = opts or {}
local type_patterns = options.type_patterns or {'class', 'function', 'method'}
local transform_fn = options.transform_fn or transform_line
local separator = options.separator or ' -> '
local current_node = ts_utils.get_node_at_cursor()
if not current_node then return nil end
local matches = {}
local expr = current_node
while expr do
local line = get_line_for_node(expr, type_patterns, transform_fn)
if line ~= '' and not vim.tbl_contains(matches, line) then
table.insert(matches, 1, { expr:start(), line })
end
expr = expr:parent()
end
if #matches == 0 then
return nil
end
return matches
end
local get_gutter_width = function()
local old_col = api.nvim_call_function('col', { '.' })
api.nvim_call_function('cursor', { 0, 1 })
@ -76,41 +53,71 @@ local nvim_augroup = function(group_name, definitions)
api.nvim_command('augroup ' .. group_name)
api.nvim_command('autocmd!')
for _, def in ipairs(definitions) do
local command = table.concat(vim.tbl_flatten{'autocmd', def}, ' ')
api.nvim_command(command)
local command = table.concat({'autocmd', unpack(def)}, ' ')
if api.nvim_call_function('exists', {'#' .. def[1]}) then
api.nvim_command(command)
end
end
api.nvim_command('augroup END')
end
-- Exports
local M = {}
function M.get_context(opts)
if not parsers.has_parser() then return nil end
local options = opts or {}
local type_patterns = options.type_patterns or {'class', 'function', 'method'}
local transform_fn = options.transform_fn or transform_line
local separator = options.separator or ' -> '
local current_node = ts_utils.get_node_at_cursor()
if not current_node then return nil end
local matches = {}
local expr = current_node
while expr do
if is_valid(expr, type_patterns) then
table.insert(matches, 1, expr)
end
expr = expr:parent()
end
if #matches == 0 then
return nil
end
return matches
end
function M.update_context()
if api.nvim_get_option('buftype') ~= '' then
return
end
local context = get_context()
local context = M.get_context()
label = nil
current_node = nil
if context then
local first_visible_line = api.nvim_call_function('line', { 'w0' })
for i = #context, 1, -1 do
local match = context[i]
local line_number, text = unpack(match)
local node = context[i]
local row = node:start()
if line_number < (first_visible_line - 1) then
label = text
if row < (first_visible_line - 1) then
current_node = node
break
end
end
end
if label then
if current_node then
M.open()
else
M.close()
@ -132,6 +139,10 @@ function M.close()
end
function M.open()
local saved_bufnr = api.nvim_get_current_buf()
local start_row = current_node:start()
local end_row = current_node:end_()
if winid == nil or not api.nvim_win_is_valid(winid) then
local gutter_width = get_gutter_width()
local win_width = api.nvim_win_get_width(0) - gutter_width
@ -147,11 +158,50 @@ function M.open()
})
-- else
-- api.nvim_win_set_config(winid, {
-- width = #label,
-- width = #current_node,
-- })
end
api.nvim_buf_set_lines(bufnr, 0, -1, false, { label })
local start_row, start_col = current_node:start()
local lines =
start_col == 0
and vim.split(get_text_for_node(current_node), '\n')
or get_lines_for_node(current_node)
local target_node =
start_col == 0
and current_node
or current_node:parent()
api.nvim_buf_clear_namespace(bufnr, ns, 0, -1)
api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
local start_row_absolute = current_node:start()
for _, highlighter in pairs(Highlighter.active[saved_bufnr] or {}) do
local iter = highlighter.query:iter_captures(target_node, saved_bufnr, start_row, end_row)
for capture, node in iter do
local start_row, start_col, end_row, end_col = node:range()
local hl = highlighter.hl_cache[capture]
if start_row >= start_row_absolute then
start_row = start_row - start_row_absolute
end_row = end_row - start_row_absolute
-- Sometimes there is an error :/
-- but we ignore it :)
-- Yay?
local ok, err = pcall(function()
api.nvim_buf_set_extmark(bufnr, ns, start_row, start_col,
{ end_line = end_row, end_col = end_col,
hl_group = hl,
-- ephemeral = true
})
end)
end
end
end
end
function M.enable()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

After

Width:  |  Height:  |  Size: 995 KiB