1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
local utils = require("nvim-surround.utils")
local ts_query = require("nvim-treesitter.query")
local ts_utils = require("nvim-treesitter.ts_utils")
local ts_parsers = require("nvim-treesitter.parsers")
local M = {}
-- Retrieves the node that corresponds exactly to a given selection.
---@param selection selection The given selection.
---@return _|nil @The corresponding node, or nil if no parser/tree is available or no node matches.
---@nodiscard
M.get_node = function(selection)
    -- Flatten the selection into { first_line, first_col, last_line, last_col }
    local range = {
        selection.first_pos[1],
        selection.first_pos[2],
        selection.last_pos[1],
        selection.last_pos[2],
    }
    -- Get the root node of the current tree. Bail out with nil (the same value as the
    -- "no match" case) instead of raising when the buffer has no parser or no parsed tree.
    local lang_tree = ts_parsers.get_parser(0)
    if not lang_tree then
        return nil
    end
    local tree = lang_tree:trees()[1]
    if not tree then
        return nil
    end
    local root = tree:root()
    -- DFS through the tree, looking for a node whose range equals the selection
    local stack = { root }
    while #stack > 0 do
        -- Pop the most recently pushed node off of the stack
        local cur = table.remove(stack)
        -- If the current node's range is equal to the desired selection, return the node
        if vim.deep_equal(range, { ts_utils.get_vim_range({ cur:range() }) }) then
            return cur
        end
        -- Add the current node's children to the stack
        for child in cur:iter_children() do
            stack[#stack + 1] = child
        end
    end
    -- No node spans exactly the given selection
    return nil
end
-- Filters an existing parent selection down to a capture.
---@param sexpr string The given S-expression containing the capture.
---@param capture string The name of the capture to be returned.
---@param parent_selection selection The parent selection to be filtered down.
---@return selection|table|nil @The selection of the first matching capture; an empty table if the parent node cannot be found or the query cannot be parsed; nil if no capture matches.
M.filter_selection = function(sexpr, capture, parent_selection)
    local parent_node = M.get_node(parent_selection)
    -- Guard before any use of the node: get_node returns nil when no node spans the
    -- selection exactly, and calling a method on nil would raise.
    if not parent_node then
        return {}
    end
    local lang_tree = ts_parsers.get_parser(0)
    -- Parse inside pcall so that an invalid S-expression (or a missing parser) is
    -- reported as an empty result rather than an error.
    local ok, parsed_query = pcall(function()
        -- Prefer the newer API name, falling back to the older one when it is absent
        local parse = vim.treesitter.query.parse or vim.treesitter.parse_query
        return parse(lang_tree:lang(), sexpr)
    end)
    if not ok or not parsed_query then
        return {}
    end
    -- Return the range of the first capture under the parent node with the requested name
    for id, node in parsed_query:iter_captures(parent_node, 0, 0, -1) do
        local name = parsed_query.captures[id]
        if name == capture then
            local range = { ts_utils.get_vim_range({ node:range() }) }
            return {
                first_pos = { range[1], range[2] },
                last_pos = { range[3], range[4] },
            }
        end
    end
    return nil
end
-- Finds the nearest selection of a given query capture and its source.
---@param capture string The capture to be retrieved.
---@param type string The type of query to get the capture from.
---@return selection|nil @The selection of the capture.
---@nodiscard
M.get_selection = function(capture, type)
    -- Collect every match for the capture in the current buffer
    local matches = ts_query.get_capture_matches_recursively(0, capture, type)
    -- Build one left/right selection pair per matched node
    local pairs_list = {}
    for _, match in ipairs(matches) do
        local first_row, first_col, last_row, last_col = ts_utils.get_vim_range({ match.node:range() })
        local pair = {}
        -- The left selection covers the whole node
        pair.left = {
            first_pos = { first_row, first_col },
            last_pos = { last_row, last_col },
        }
        -- The right selection starts one column past the node's end and ends at the node's end
        pair.right = {
            first_pos = { last_row, last_col + 1 },
            last_pos = { last_row, last_col },
        }
        pairs_list[#pairs_list + 1] = pair
    end
    -- Pick the best pair out of the candidates
    local best = utils.filter_selections_list(pairs_list)
    if not best then
        return best
    end
    return {
        first_pos = best.left.first_pos,
        last_pos = best.right.last_pos,
    }
end
return M
|