diff --git a/lua/treesitter-context.lua b/lua/treesitter-context.lua index 9d8d5bc..d182881 100644 --- a/lua/treesitter-context.lua +++ b/lua/treesitter-context.lua @@ -1,6 +1,8 @@ local vim = vim local api = vim.api +local ts = vim.treesitter +local Highlighter = ts.highlighter local ts_utils = require'nvim-treesitter.ts_utils' local parsers = require'nvim-treesitter.parsers' @@ -8,27 +10,30 @@ local parsers = require'nvim-treesitter.parsers' local winid = nil local bufnr = api.nvim_create_buf(false, true) -local label = nil +local ns = api.nvim_create_namespace('nvim-treesitter-context') +local current_node = nil -- Helper functions -local get_line_for_node = function(node, type_patterns, transform_fn) +local is_valid = function(node, type_patterns) local node_type = node:type() - local is_valid = false for _, rgx in ipairs(type_patterns) do if node_type:find(rgx) then - is_valid = true - break + return true end end + return false +end - if not is_valid then return '' end +local get_text_for_node = function(node) + return ts_utils.get_node_text(node)[1] +end - -- local node_text = transform_fn(ts_utils.get_node_text(node)[1] or '') - local line = api.nvim_buf_get_lines(0, node:start(), node:start() + 1, false)[1] - - return line +local get_lines_for_node = function(node) + local start_row = node:start() + local end_row = node:end_() + return api.nvim_buf_get_lines(0, start_row, end_row + 1, false) end -- Trim spaces and opening brackets from end @@ -36,34 +41,6 @@ local transform_line = function(line) return line:gsub('%s*[%[%(%{]*%s*$', ''):gsub('\n', '') end -local get_context = function(opts) - if not parsers.has_parser() then return nil end - local options = opts or {} - local type_patterns = options.type_patterns or {'class', 'function', 'method'} - local transform_fn = options.transform_fn or transform_line - local separator = options.separator or ' -> ' - - local current_node = ts_utils.get_node_at_cursor() - if not current_node then return nil end - - local matches = {} - local expr = current_node - - while expr do - local line = get_line_for_node(expr, type_patterns, transform_fn) - if line ~= '' and not vim.tbl_contains(matches, line) then - table.insert(matches, 1, { expr:start(), line }) - end - expr = expr:parent() - end - - if #matches == 0 then - return nil - end - - return matches -end - local get_gutter_width = function() local old_col = api.nvim_call_function('col', { '.' }) api.nvim_call_function('cursor', { 0, 1 }) @@ -76,41 +53,71 @@ local nvim_augroup = function(group_name, definitions) api.nvim_command('augroup ' .. group_name) api.nvim_command('autocmd!') for _, def in ipairs(definitions) do - local command = table.concat(vim.tbl_flatten{'autocmd', def}, ' ') - api.nvim_command(command) + local command = table.concat({'autocmd', unpack(def)}, ' ') + if api.nvim_call_function('exists', {'#' .. def[1]}) then + api.nvim_command(command) + end end api.nvim_command('augroup END') end + -- Exports local M = {} +function M.get_context(opts) + if not parsers.has_parser() then return nil end + local options = opts or {} + local type_patterns = options.type_patterns or {'class', 'function', 'method'} + local transform_fn = options.transform_fn or transform_line + local separator = options.separator or ' -> ' + + local current_node = ts_utils.get_node_at_cursor() + if not current_node then return nil end + + local matches = {} + local expr = current_node + + while expr do + if is_valid(expr, type_patterns) then + table.insert(matches, 1, expr) + end + expr = expr:parent() + end + + if #matches == 0 then + return nil + end + + return matches +end + function M.update_context() if api.nvim_get_option('buftype') ~= '' then return end - local context = get_context() + local context = M.get_context() - label = nil + current_node = nil if context then local first_visible_line = api.nvim_call_function('line', { 'w0' }) for i = #context, 1, -1 do - local match = context[i] - local line_number, text = unpack(match) + local node = context[i] + local row = node:start() - if line_number < (first_visible_line - 1) then - label = text + if row < (first_visible_line - 1) then + current_node = node break end end end - if label then + if current_node then M.open() else M.close() @@ -132,6 +139,10 @@ function M.close() end function M.open() + local saved_bufnr = api.nvim_get_current_buf() + local start_row = current_node:start() + local end_row = current_node:end_() + if winid == nil or not api.nvim_win_is_valid(winid) then local gutter_width = get_gutter_width() local win_width = api.nvim_win_get_width(0) - gutter_width @@ -147,11 +158,50 @@ function M.open() }) -- else -- api.nvim_win_set_config(winid, { - -- width = #label, + -- width = #current_node, -- }) end - api.nvim_buf_set_lines(bufnr, 0, -1, false, { label }) + local start_row, start_col = current_node:start() + local lines = + start_col == 0 + and vim.split(get_text_for_node(current_node), '\n') + or get_lines_for_node(current_node) + local target_node = + start_col == 0 + and current_node + or current_node:parent() + + api.nvim_buf_clear_namespace(bufnr, ns, 0, -1) + api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + + local start_row_absolute = current_node:start() + + for _, highlighter in pairs(Highlighter.active[saved_bufnr] or {}) do + local iter = highlighter.query:iter_captures(target_node, saved_bufnr, start_row, end_row) + for capture, node in iter do + + local start_row, start_col, end_row, end_col = node:range() + local hl = highlighter.hl_cache[capture] + + if start_row >= start_row_absolute then + + start_row = start_row - start_row_absolute + end_row = end_row - start_row_absolute + + -- Sometimes there is an error :/ + -- but we ignore it :) + -- Yay? + local ok, err = pcall(function() + api.nvim_buf_set_extmark(bufnr, ns, start_row, start_col, + { end_line = end_row, end_col = end_col, + hl_group = hl, + -- ephemeral = true + }) + end) + end + end + end end function M.enable() diff --git a/static/demo.gif b/static/demo.gif index c77a6d7..9d427af 100644 Binary files a/static/demo.gif and b/static/demo.gif differ