chore: formatting

This commit is contained in:
Lewis Russell 2023-07-13 09:53:14 +01:00
parent 4d497f8e56
commit c315c973be
2 changed files with 126 additions and 103 deletions

6
.stylua.toml Normal file
View file

@ -0,0 +1,6 @@
column_width = 100
line_endings = "Unix"
indent_type = "Spaces"
indent_width = 2
quote_style = "AutoPreferSingle"
call_parentheses = "Always"

View file

@ -1,22 +1,17 @@
local api = vim.api local fn, api = vim.fn, vim.api
local highlighter = vim.treesitter.highlighter local highlighter = vim.treesitter.highlighter
local cache = require'treesitter-context.cache' local cache = require('treesitter-context.cache')
local get_lang = local get_lang = vim.treesitter.language.get_lang or require('nvim-treesitter.parsers').ft_to_lang
vim.treesitter.language.get_lang
or require'nvim-treesitter.parsers'.ft_to_lang
local get_query = --- @diagnostic disable-next-line:deprecated
vim.treesitter.query.get or local get_query = vim.treesitter.query.get or vim.treesitter.query.get_query
vim.treesitter.query.get_query
local augroup = api.nvim_create_augroup local augroup = api.nvim_create_augroup
local command = api.nvim_create_user_command local command = api.nvim_create_user_command
---@diagnostic disable:invisible --- @class TSContext.Config
--- @class Config
--- @field enable boolean --- @field enable boolean
--- @field max_lines integer --- @field max_lines integer
--- @field min_window_height integer --- @field min_window_height integer
@ -28,7 +23,7 @@ local command = api.nvim_create_user_command
--- @field separator? string --- @field separator? string
--- @field on_attach? fun(buf: integer): boolean --- @field on_attach? fun(buf: integer): boolean
--- @type Config --- @type TSContext.Config
local defaultConfig = { local defaultConfig = {
enable = true, enable = true,
max_lines = 0, -- no limit max_lines = 0, -- no limit
@ -40,7 +35,7 @@ local defaultConfig = {
mode = 'cursor', mode = 'cursor',
} }
--- @type Config --- @type TSContext.Config
local config = {} local config = {}
-- Constants -- Constants
@ -53,17 +48,11 @@ local did_setup = false
local enabled = false local enabled = false
-- Don't access directly, use get_bufs() -- Don't access directly, use get_bufs()
--- @type integer? local gutter_bufnr --- @type integer?
local gutter_winid local context_bufnr --- @type integer?
--- @type integer? local gutter_winid --- @type integer?
local context_winid local context_winid --- @type integer?
--- @type integer?
local gutter_bufnr
--- @type integer?
local context_bufnr
local ns = api.nvim_create_namespace('nvim-treesitter-context') local ns = api.nvim_create_namespace('nvim-treesitter-context')
@ -75,7 +64,6 @@ local all_contexts = {}
--- @return TSNode --- @return TSNode
local function get_root_node() local function get_root_node()
---@diagnostic disable-next-line
local tree = vim.treesitter.get_parser():parse()[1] local tree = vim.treesitter.get_parser():parse()[1]
return tree:root() return tree:root()
end end
@ -88,7 +76,7 @@ local function hash_node(node)
node:symbol(), node:symbol(),
node:child_count(), node:child_count(),
node:type(), node:type(),
node:range() node:range(),
}, ',') }, ',')
end end
@ -97,7 +85,7 @@ end
--- @return Range4? --- @return Range4?
local is_valid = cache.memoize(function(node, query) local is_valid = cache.memoize(function(node, query)
local bufnr = api.nvim_get_current_buf() local bufnr = api.nvim_get_current_buf()
local range --[[@type Range4]] = {node:range()} local range = { node:range() } --- @type Range4
range[3] = range[1] range[3] = range[1]
range[4] = -1 range[4] = -1
@ -138,15 +126,15 @@ local function get_text_for_range(range)
range[4] = -1 range[4] = -1
end end
local lines = api.nvim_buf_get_text(0, range[1], 0, range[3], range[4], {}) local lines = api.nvim_buf_get_text(0, range[1], 0, range[3], range[4], {})
if lines == nil then if not lines then
return nil, nil return
end end
local start_row = range[1] local start_row = range[1]
local end_row = range[3] local end_row = range[3]
local end_col = range[4] local end_col = range[4]
lines = vim.list_slice(lines, 1, end_row - start_row+1) lines = vim.list_slice(lines, 1, end_row - start_row + 1)
lines[#lines] = lines[#lines]:sub(1, end_col) lines[#lines] = lines[#lines]:sub(1, end_col)
if #lines > config.multiline_threshold then if #lines > config.multiline_threshold then
@ -155,9 +143,7 @@ local function get_text_for_range(range)
end_col = #lines[1] end_col = #lines[1]
end end
range = {start_row, 0, end_row, end_col} return lines, { start_row, 0, end_row, end_col }
return lines, range
end end
-- Merge lines, removing the indentation after 1st line -- Merge lines, removing the indentation after 1st line
@ -176,12 +162,16 @@ end
--- @return integer[] --- @return integer[]
local function get_indents(lines) local function get_indents(lines)
--- @type integer[] --- @type integer[]
--- @diagnostic disable-next-line local indents = vim.tbl_map(
local indents = vim.tbl_map(function(line) --- @param line string
--- @type string? --- @return integer
local indent = line:match(INDENT_PATTERN) function(line)
return indent and #indent or 0 --- @type string?
end, lines) local indent = line:match(INDENT_PATTERN)
return indent and #indent or 0
end,
lines
)
-- Dont skip first line indentation -- Dont skip first line indentation
indents[1] = 0 indents[1] = 0
return indents return indents
@ -189,14 +179,14 @@ end
--- @return integer --- @return integer
local function get_gutter_width() local function get_gutter_width()
return vim.fn.getwininfo(vim.api.nvim_get_current_win())[1].textoff return fn.getwininfo(api.nvim_get_current_win())[1].textoff
end end
local cursor_moved_vertical --[[@type fun(): boolean]] local cursor_moved_vertical --- @type fun(): boolean
do do
local line --[[@type integer]] local line --- @type integer?
cursor_moved_vertical = function() cursor_moved_vertical = function()
local newline = vim.api.nvim_win_get_cursor(0)[1] local newline = api.nvim_win_get_cursor(0)[1]
if newline ~= line then if newline ~= line then
line = newline line = newline
return true return true
@ -239,7 +229,7 @@ end
--- @return integer --- @return integer
local function display_window(bufnr, winid, width, height, col, ty, hl) local function display_window(bufnr, winid, width, height, col, ty, hl)
if not winid or not api.nvim_win_is_valid(winid) then if not winid or not api.nvim_win_is_valid(winid) then
local sep = config.separator and { config.separator, "TreesitterContextSeparator" } or nil local sep = config.separator and { config.separator, 'TreesitterContextSeparator' } or nil
winid = api.nvim_open_win(bufnr, false, { winid = api.nvim_open_win(bufnr, false, {
relative = 'win', relative = 'win',
width = width, width = width,
@ -250,7 +240,7 @@ local function display_window(bufnr, winid, width, height, col, ty, hl)
style = 'minimal', style = 'minimal',
noautocmd = true, noautocmd = true,
zindex = config.zindex, zindex = config.zindex,
border = sep and {'', '', '', '', sep, sep, sep, ''} or nil, border = sep and { '', '', '', '', sep, sep, sep, '' } or nil,
}) })
vim.w[winid][ty] = true vim.w[winid][ty] = true
vim.wo[winid].wrap = false vim.wo[winid].wrap = false
@ -282,7 +272,7 @@ local function get_node_parents(node)
--- @type TSNode[] --- @type TSNode[]
local parents = {} local parents = {}
while node ~= nil do while node ~= nil do
parents[#parents+1] = node parents[#parents + 1] = node
node = node:parent() node = node:parent()
end end
return parents return parents
@ -293,7 +283,9 @@ local function get_pos()
--- @type integer, integer --- @type integer, integer
local lnum, col local lnum, col
if config.mode == 'topline' then if config.mode == 'topline' then
lnum, col = vim.fn.line('w0') --[[@as integer]], 0 lnum, col =
fn.line('w0'), --[[@as integer]]
0
else -- default to 'cursor' else -- default to 'cursor'
lnum, col = unpack(api.nvim_win_get_cursor(0)) --[[@as integer]] lnum, col = unpack(api.nvim_win_get_cursor(0)) --[[@as integer]]
end end
@ -350,7 +342,7 @@ local function get_parent_matches(max_lines)
last_matches = parent_matches last_matches = parent_matches
parent_matches = {} parent_matches = {}
local last_row = -1 local last_row = -1
local topline = vim.fn.line('w0') local topline = fn.line('w0')
-- save nodes in a table to iterate from top to bottom -- save nodes in a table to iterate from top to bottom
local parents = get_node_parents(node) local parents = get_node_parents(node)
@ -365,7 +357,7 @@ local function get_parent_matches(max_lines)
if row == last_row then if row == last_row then
parent_matches[#parent_matches] = range parent_matches[#parent_matches] = range
else else
parent_matches[#parent_matches+1] = range parent_matches[#parent_matches + 1] = range
last_row = row last_row = row
local new_height = math.min(max_lines, #parent_matches) local new_height = math.min(max_lines, #parent_matches)
@ -379,11 +371,7 @@ local function get_parent_matches(max_lines)
until config.mode ~= 'topline' or #last_matches >= #parent_matches until config.mode ~= 'topline' or #last_matches >= #parent_matches
if config.trim_scope == 'inner' then if config.trim_scope == 'inner' then
return vim.list_slice( return vim.list_slice(parent_matches, 1, math.min(#parent_matches, max_lines))
parent_matches,
1,
math.min(#parent_matches, max_lines)
)
else -- default to 'outer' else -- default to 'outer'
return vim.list_slice( return vim.list_slice(
parent_matches, parent_matches,
@ -394,10 +382,10 @@ local function get_parent_matches(max_lines)
end end
--- @generic F: function --- @generic F: function
--- @param fn F --- @param f F
--- @param ms? number --- @param ms? number
--- @return F --- @return F
local function throttle(fn, ms) local function throttle(f, ms)
ms = ms or 200 ms = ms or 200
local timer = assert(vim.loop.new_timer()) local timer = assert(vim.loop.new_timer())
local waiting = 0 local waiting = 0
@ -407,30 +395,32 @@ local function throttle(fn, ms)
return return
end end
waiting = 0 waiting = 0
fn() -- first call, execute immediately f() -- first call, execute immediately
timer:start(ms, 0, function() timer:start(ms, 0, function()
if waiting > 1 then if waiting > 1 then
vim.schedule(fn) -- only execute if there are calls waiting vim.schedule(f) -- only execute if there are calls waiting
end end
end) end)
end end
end end
local function win_close(winid)
if winid ~= nil and api.nvim_win_is_valid(winid) then
api.nvim_win_close(winid, true)
end
end
local function close() local function close()
previous_nodes = nil previous_nodes = nil
-- Can't close other windows when the command-line window is open -- Can't close other windows when the command-line window is open
if vim.fn.getcmdwintype() ~= '' then if fn.getcmdwintype() ~= '' then
return return
end end
if context_winid ~= nil and api.nvim_win_is_valid(context_winid) then win_close(context_winid)
api.nvim_win_close(context_winid, true)
end
context_winid = nil context_winid = nil
if gutter_winid and api.nvim_win_is_valid(gutter_winid) then win_close(gutter_winid)
api.nvim_win_close(gutter_winid, true)
end
gutter_winid = nil gutter_winid = nil
end end
@ -504,8 +494,10 @@ local function highlight_contexts(bufnr, ctx_bufnr, contexts)
for capture, node in query:iter_captures(root, bufnr, start_row, end_row + 1) do for capture, node in query:iter_captures(root, bufnr, start_row, end_row + 1) do
local node_start_row, node_start_col, node_end_row, node_end_col = node:range() local node_start_row, node_start_col, node_end_row, node_end_col = node:range()
if node_end_row > end_row or if
(node_end_row == end_row and node_end_col > end_col and end_col ~= -1) then node_end_row > end_row
or (node_end_row == end_row and node_end_col > end_col and end_col ~= -1)
then
break break
end end
@ -526,7 +518,7 @@ local function highlight_contexts(bufnr, ctx_bufnr, contexts)
api.nvim_buf_set_extmark(ctx_bufnr, ns, row, node_start_col + offset, { api.nvim_buf_set_extmark(ctx_bufnr, ns, row, node_start_col + offset, {
end_line = row, end_line = row,
end_col = node_end_col + offset, end_col = node_end_col + offset,
hl_group = buf_query.hl_cache[capture] hl_group = buf_query.hl_cache[capture],
}) })
end end
end end
@ -543,8 +535,9 @@ end
--- @param width integer --- @param width integer
--- @return string, StatusLineHighlight[]? --- @return string, StatusLineHighlight[]?
local function build_lno_str(win, lnum, relnum, width) local function build_lno_str(win, lnum, relnum, width)
local has_col, statuscol = pcall(api.nvim_get_option_value, 'statuscolumn', {win = win, scope = "local"}) local has_col, statuscol =
if has_col and statuscol and statuscol ~= "" then pcall(api.nvim_get_option_value, 'statuscolumn', { win = win, scope = 'local' })
if has_col and statuscol and statuscol ~= '' then
local ok, data = pcall(api.nvim_eval_statusline, statuscol, { local ok, data = pcall(api.nvim_eval_statusline, statuscol, {
winid = win, winid = win,
use_statuscol_lnum = lnum, use_statuscol_lnum = lnum,
@ -557,7 +550,7 @@ local function build_lno_str(win, lnum, relnum, width)
if relnum then if relnum then
lnum = relnum lnum = relnum
end end
return string.format('%'..width..'d', lnum) return string.format('%' .. width .. 'd', lnum)
end end
--- @param buf integer --- @param buf integer
@ -571,23 +564,29 @@ local function highlight_lno_str(buf, text, highlights)
if col ~= endcol then if col ~= endcol then
api.nvim_buf_set_extmark(buf, ns, line - 1, col, { api.nvim_buf_set_extmark(buf, ns, line - 1, col, {
end_col = endcol, end_col = endcol,
hl_group=hl.group:find("LineNr") and "TreesitterContextLineNumber" or hl.group hl_group = hl.group:find('LineNr') and 'TreesitterContextLineNumber' or hl.group,
}) })
end end
end end
end end
api.nvim_buf_set_extmark(buf, ns, #text-1, 0, {end_line=#text, hl_group='TreesitterContextBottom', hl_eol=true}) api.nvim_buf_set_extmark(
buf,
ns,
#text - 1,
0,
{ end_line = #text, hl_group = 'TreesitterContextBottom', hl_eol = true }
)
end end
--- @param ctx_node_line_num integer --- @param ctx_node_line_num integer
--- @return integer --- @return integer
local function get_relative_line_num(ctx_node_line_num) local function get_relative_line_num(ctx_node_line_num)
local cursor_line_num = vim.fn.line('.') local cursor_line_num = fn.line('.')
local num_folded_lines = 0 local num_folded_lines = 0
-- Find all folds between the context node and the cursor -- Find all folds between the context node and the cursor
local current_line = ctx_node_line_num local current_line = ctx_node_line_num
while current_line < cursor_line_num do while current_line < cursor_line_num do
local fold_end = vim.fn.foldclosedend(current_line) local fold_end = fn.foldclosedend(current_line)
if fold_end == -1 then if fold_end == -1 then
current_line = current_line + 1 current_line = current_line + 1
else else
@ -602,12 +601,12 @@ local function horizontal_scroll_contexts()
if context_winid == nil then if context_winid == nil then
return return
end end
local active_win_view = vim.fn.winsaveview() local active_win_view = fn.winsaveview()
local context_win_view = api.nvim_win_call(context_winid, vim.fn.winsaveview) local context_win_view = api.nvim_win_call(context_winid, fn.winsaveview)
if active_win_view.leftcol ~= context_win_view.leftcol then if active_win_view.leftcol ~= context_win_view.leftcol then
context_win_view.leftcol = active_win_view.leftcol context_win_view.leftcol = active_win_view.leftcol
api.nvim_win_call(context_winid, function() api.nvim_win_call(context_winid, function()
return vim.fn.winrestview({leftcol = context_win_view.leftcol}) return fn.winrestview({ leftcol = context_win_view.leftcol })
end) end)
end end
end end
@ -623,27 +622,39 @@ local function open(ctx_ranges)
local win = api.nvim_get_current_win() local win = api.nvim_get_current_win()
local gutter_width = get_gutter_width() local gutter_width = get_gutter_width()
local win_width = math.max(1, api.nvim_win_get_width(0) - gutter_width) local win_width = math.max(1, api.nvim_win_get_width(0) - gutter_width)
local win_height = math.max(1, #ctx_ranges) local win_height = math.max(1, #ctx_ranges)
local gbufnr, ctx_bufnr = get_bufs() local gbufnr, ctx_bufnr = get_bufs()
if config.line_numbers and (vim.wo.number or vim.wo.relativenumber) then if config.line_numbers and (vim.wo.number or vim.wo.relativenumber) then
gutter_winid = display_window( gutter_winid = display_window(
gbufnr, gutter_winid, gutter_width, win_height, 0, gbufnr,
'treesitter_context_line_number', 'TreesitterContextLineNumber') gutter_winid,
gutter_width,
win_height,
0,
'treesitter_context_line_number',
'TreesitterContextLineNumber'
)
end end
context_winid = display_window( context_winid = display_window(
ctx_bufnr, context_winid, win_width, win_height, gutter_width, ctx_bufnr,
'treesitter_context', 'TreesitterContext') context_winid,
win_width,
win_height,
gutter_width,
'treesitter_context',
'TreesitterContext'
)
-- Set text -- Set text
local context_text --[[@type string[] ]] = {} local context_text = {} --- @type string[]
local lno_text --[[@type string[] ]] = {} local lno_text = {} --- @type string[]
local lno_highlights --[[@type StatusLineHighlight[][] ]] = {} local lno_highlights = {} --- @type StatusLineHighlight[][]
local contexts --[[@type Context[] ]] = {} local contexts = {} --- @type Context[]
for _, range0 in ipairs(ctx_ranges) do for _, range0 in ipairs(ctx_ranges) do
local lines, range = get_text_for_range(range0) local lines, range = get_text_for_range(range0)
@ -652,7 +663,7 @@ local function open(ctx_ranges)
end end
local text = merge_lines(lines) local text = merge_lines(lines)
contexts[#contexts+1] = { contexts[#contexts + 1] = {
lines = lines, lines = lines,
range = range, range = range,
indents = get_indents(lines), indents = get_indents(lines),
@ -661,11 +672,11 @@ local function open(ctx_ranges)
table.insert(context_text, text) table.insert(context_text, text)
local ctx_line_num = range[1] + 1 local ctx_line_num = range[1] + 1
local relnum --- @type integer? local relnum --- @type integer?
if vim.wo[win].relativenumber then if vim.wo[win].relativenumber then
relnum = get_relative_line_num(ctx_line_num) relnum = get_relative_line_num(ctx_line_num)
end end
local txt, hl = build_lno_str(win, ctx_line_num, relnum, gutter_width-1) local txt, hl = build_lno_str(win, ctx_line_num, relnum, gutter_width - 1)
table.insert(lno_text, txt) table.insert(lno_text, txt)
table.insert(lno_highlights, hl) table.insert(lno_highlights, hl)
end end
@ -682,7 +693,13 @@ local function open(ctx_ranges)
highlight_contexts(bufnr, ctx_bufnr, contexts) highlight_contexts(bufnr, ctx_bufnr, contexts)
api.nvim_buf_set_extmark(ctx_bufnr, ns, #lno_text-1, 0, {end_line=#lno_text, hl_group='TreesitterContextBottom', hl_eol=true}) api.nvim_buf_set_extmark(
ctx_bufnr,
ns,
#lno_text - 1,
0,
{ end_line = #lno_text, hl_group = 'TreesitterContextBottom', hl_eol = true }
)
end end
--- @param config_max integer --- @param config_max integer
@ -691,8 +708,8 @@ local function calc_max_lines(config_max)
local max_lines = config_max local max_lines = config_max
max_lines = max_lines == 0 and -1 or max_lines max_lines = max_lines == 0 and -1 or max_lines
local wintop = vim.fn.line('w0') local wintop = fn.line('w0')
local cursor = vim.fn.line('.') local cursor = fn.line('.')
local max_from_cursor = cursor - wintop local max_from_cursor = cursor - wintop
if config.separator and max_from_cursor > 0 then if config.separator and max_from_cursor > 0 then
@ -790,10 +807,10 @@ function M.enable()
end end
end) end)
autocmd({'BufLeave', 'WinLeave'}, close) autocmd({ 'BufLeave', 'WinLeave' }, close)
autocmd('User', {close , pattern = 'SessionSavePre' }) autocmd('User', { close, pattern = 'SessionSavePre' })
autocmd('User', {update, pattern = 'SessionSavePost' }) autocmd('User', { update, pattern = 'SessionSavePost' })
update() update()
enabled = true enabled = true
@ -832,7 +849,7 @@ function M.setup(options)
end end
function M.go_to_context() function M.go_to_context()
local line = vim.api.nvim_win_get_cursor(0)[1] local line = api.nvim_win_get_cursor(0)[1]
local context = nil local context = nil
local bufnr = api.nvim_get_current_buf() local bufnr = api.nvim_get_current_buf()
local contexts = all_contexts[bufnr] or {} local contexts = all_contexts[bufnr] or {}
@ -847,17 +864,17 @@ function M.go_to_context()
return return
end end
vim.api.nvim_win_set_cursor(0, { context.range[1] + 1, context.range[2] }) api.nvim_win_set_cursor(0, { context.range[1] + 1, context.range[2] })
end end
command('TSContextEnable' , M.enable , {}) command('TSContextEnable', M.enable, {})
command('TSContextDisable', M.disable, {}) command('TSContextDisable', M.disable, {})
command('TSContextToggle' , M.toggle , {}) command('TSContextToggle', M.toggle, {})
api.nvim_set_hl(0, 'TreesitterContext', {link = 'NormalFloat', default = true}) api.nvim_set_hl(0, 'TreesitterContext', { link = 'NormalFloat', default = true })
api.nvim_set_hl(0, 'TreesitterContextLineNumber', {link = 'LineNr', default = true}) api.nvim_set_hl(0, 'TreesitterContextLineNumber', { link = 'LineNr', default = true })
api.nvim_set_hl(0, 'TreesitterContextBottom', {link = 'NONE', default = true}) api.nvim_set_hl(0, 'TreesitterContextBottom', { link = 'NONE', default = true })
api.nvim_set_hl(0, 'TreesitterContextSeparator', {link = 'FloatBorder', default = true}) api.nvim_set_hl(0, 'TreesitterContextSeparator', { link = 'FloatBorder', default = true })
-- Setup with default options if user didn't call setup() -- Setup with default options if user didn't call setup()
autocmd_for_group('treesitter_context')('VimEnter', function() autocmd_for_group('treesitter_context')('VimEnter', function()