1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
|
local buffer = require("nvim-surround.buffer")
local M = {}
-- Converts a 1D index into the buffer to the corresponding 2D buffer position.
-- The buffer is treated as its lines joined with "\n". Only the characters strictly before
-- `index` decide which line it is on, so a newline character at `index` is reported one column
-- past the end of the line it terminates. Any index on the first line maps to { 1, index }.
---@param index integer The index of the character in the string.
---@return position @The position of the character in the buffer.
---@nodiscard
M.index_to_pos = function(index)
    local buffer_text = table.concat(buffer.get_lines(1, -1), "\n")
    -- Clamp the prefix end at 0 rather than 1: `sub(1, -1)` would return the whole buffer, while
    -- `sub(1, 1)` would count a leading newline for index 0 or 1 and yield a bogus position on
    -- line 2 with column 0 or -1.
    local preceding = buffer_text:sub(1, math.max(0, index - 1))
    -- The number of newlines before `index`, plus one, is its line number
    local _, newline_count = preceding:gsub("\n", "")
    local lnum = newline_count + 1
    -- Special case for first line, as there are no newline characters preceding it
    if lnum == 1 then
        return { 1, index }
    end
    -- The column is the distance from the last newline preceding `index`; `()` captures the
    -- position of that newline (greedy `.*` backtracks to the last one)
    local last_newline = preceding:match(".*()\n")
    return { lnum, index - last_newline }
end
-- Converts a 2D position in the buffer to the corresponding 1D string index.
-- The index refers to the buffer's lines joined with "\n": every line above `pos[1]` contributes
-- its length plus one for the newline that terminates it.
---@param pos position The position in the buffer.
---@return integer @The index of the character into the buffer.
---@nodiscard
M.pos_to_index = function(pos)
    local row, col = pos[1], pos[2]
    if row ~= 1 then
        -- All preceding lines joined by "\n", plus the newline ending the previous line
        local preceding = table.concat(buffer.get_lines(1, row - 1), "\n")
        return #preceding + 1 + col
    end
    -- Nothing precedes the first line, so the column already is the index
    return col
end
-- Widens a selection so both of its ends properly contain multi-byte characters.
-- The given selection table is modified in place and also returned for convenience.
---@param selection selection The given selection.
---@return selection @The adjusted selection, handling multi-byte characters.
---@nodiscard
M.adjust_selection = function(selection)
    local start_pos, end_pos = selection.first_pos, selection.last_pos
    selection.first_pos = buffer.get_first_byte(start_pos)
    selection.last_pos = buffer.get_last_byte(end_pos)
    return selection
end
-- Returns a selection in the buffer based on a Lua pattern.
-- Two candidate matches are computed: the "after" match (a_*), i.e. the first match starting at
-- or after the cursor, and the "before" match (b_*), found by searching backwards from the cursor.
-- The before match is returned when it ends at or after the cursor; otherwise the after match is
-- returned if it exists, falling back to the before match. Returns nil when nothing matches.
---@param find string The Lua pattern to find in the buffer.
---@return selection|nil @The closest selection matching the pattern, if any.
---@nodiscard
M.get_selection = function(find)
-- Get the current cursor position, buffer contents
local curpos = buffer.get_curpos()
local buffer_text = table.concat(buffer.get_lines(1, -1), "\n")
-- Find which character the cursor is in the file
local cursor_index = M.pos_to_index(curpos)
-- Find the character positions of the pattern in the file (after/on the cursor)
local a_first, a_last = buffer_text:find(find, cursor_index)
-- Find the character positions of the pattern in the file (before the cursor)
local b_first, b_last
-- Linewise search for the pattern before/on the cursor: starting on the cursor's line and moving
-- up, search forward from the beginning of each line until a match starts at or before the
-- cursor. If no line qualifies, b_* keep the result of the final search (the one for line 1),
-- which may be nil or a match starting after the cursor.
for lnum = curpos[1], 1, -1 do
-- Text of the lines above line `lnum` (lines 1..lnum-1)
local cur_text = table.concat(buffer.get_lines(1, lnum - 1), "\n")
-- Search forward from index #cur_text + 1: the start of the buffer when no lines precede,
-- otherwise the newline ending line lnum - 1 (assumes get_lines(1, n) returns lines 1..n,
-- which pos_to_index also relies on)
b_first, b_last = buffer_text:find(find, #cur_text + 1)
if b_first and b_first <= cursor_index then
break
end
end
-- If no match found, return the after one, if it exists (nil otherwise)
if not b_first or not b_last then
return a_first
and a_last
and M.adjust_selection({
first_pos = M.index_to_pos(a_first),
last_pos = M.index_to_pos(a_last),
})
end
-- Adjust the selection character-wise: walk backwards from the cursor to the start of the
-- line-wise match, searching forward from every index and keeping whichever candidate the rules
-- below rank best. Despite their names, start_col and end_col are 1D buffer indices.
local start_col, end_col = cursor_index, b_first
b_first, b_last = nil, nil
for index = start_col, end_col, -1 do
local c_first, c_last = buffer_text:find(find, index)
-- Validate if there is a current match
if c_last then
-- If no match yet or the current match is "better", use the current match
if
not (b_first and b_last) -- No match yet
or (b_last == c_last) -- Extending current match
or (cursor_index < b_first and c_first < b_first) -- Current is closer to cursor, after case
or (b_last < cursor_index and b_last < c_last) -- Current is closer to cursor, before case
then
b_first, b_last = c_first, c_last
end
end
end
-- If the cursor is inside the range (the chosen match ends at or after it) then return it
if b_last and b_first and b_last >= cursor_index then
return M.adjust_selection({
first_pos = M.index_to_pos(b_first),
last_pos = M.index_to_pos(b_last),
})
end
-- Else if there's a range found after the cursor, return it
if a_first and a_last then
return M.adjust_selection({
first_pos = M.index_to_pos(a_first),
last_pos = M.index_to_pos(a_last),
})
end
-- Otherwise return the range found before the cursor, if one exists
if b_first and b_last then
return M.adjust_selection({
first_pos = M.index_to_pos(b_first),
last_pos = M.index_to_pos(b_last),
})
end
end
-- Builds the adjusted buffer selection for one delimiter found inside the selection text.
-- `after` is the position (within the selection text) just past the delimiter and `len` its byte
-- length; `offset` is the buffer index of the selection's first character. A zero length yields
-- an empty selection whose last index is one less than its first.
local function delimiter_selection(offset, after, len)
    return M.adjust_selection({
        first_pos = M.index_to_pos(offset + after - len - 1),
        last_pos = M.index_to_pos(offset + after - 2),
    })
end

-- Moves a position whose column lies past the end of its line to column 0 of the next line.
-- Mutates `pos` in place.
local function move_past_line_end(pos)
    if pos[2] > #buffer.get_line(pos[1]) then
        pos[1] = pos[1] + 1
        pos[2] = 0
    end
end

-- Finds the start and end indices for the given match groups.
-- The pattern needs four captures, in order: the left delimiter, an empty position capture `()`
-- right after it, the right delimiter, and an empty position capture right after that.
-- Returns nil when the pattern does not match; malformed patterns are reported via vim.notify
-- and also yield nil.
---@param selection selection The parent selection encompassing the delimiter pair.
---@param pattern string The given Lua pattern to extract match groups from.
---@return selections|nil @The selections for the left and right delimiters.
---@nodiscard
M.get_selections = function(selection, pattern)
    local offset = M.pos_to_index(selection.first_pos)
    local text = table.concat(buffer.get_text(selection), "\n")
    local found, _, left_delim, left_end, right_delim, right_end = text:find(pattern)
    if not found then
        return nil
    end
    -- Fewer than four captures leaves the last one unset
    if not right_end then
        vim.notify(
            "Four match groups must be present in the Lua pattern, see :h nvim-surround.config.get_selections().",
            vim.log.levels.ERROR
        )
        return nil
    end
    -- Only empty `()` captures produce numbers
    local positions_ok = type(left_end) == "number" and type(right_end) == "number"
    if not positions_ok then
        vim.notify(
            "The second and last capture groups must be empty, see :h nvim-surround.config.get_selections().",
            vim.log.levels.ERROR
        )
        return nil
    end
    ---@cast left_end integer
    ---@cast right_end integer
    -- A delimiter that is not a string capture counts as zero-length (an empty selection)
    local left_len = type(left_delim) == "string" and #left_delim or 0
    local right_len = type(right_delim) == "string" and #right_delim or 0
    local selections = {
        left = delimiter_selection(offset, left_end, left_len),
        right = delimiter_selection(offset, right_end, right_len),
    }
    -- Handle special case where the column is invalid
    move_past_line_end(selections.left.last_pos)
    move_past_line_end(selections.right.last_pos)
    return selections
end
return M
|