summaryrefslogtreecommitdiff
path: root/stowables-dotlocal/share/nvim/site/pack/manual/start/nvim-surround-v2.1.7/lua/nvim-surround/patterns.lua
blob: 024998d03e02ea21bae159a4ae4c6ffbe2185d4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
local buffer = require("nvim-surround.buffer")

local M = {}

-- Converts a 1D index into the buffer to the corresponding 2D buffer position.
-- The buffer is treated as its lines joined with "\n", so an index that lands on a line's terminating newline
-- maps to the column one past that line's last byte.
---@param index integer The index of the character in the string.
---@return position @The position of the character in the buffer.
---@nodiscard
M.index_to_pos = function(index)
    local buffer_text = table.concat(buffer.get_lines(1, -1), "\n")
    -- The line number is one more than the number of newlines strictly before `index`. The prefix length is
    -- clamped at zero: a negative end index would make `sub` return the whole buffer (e.g. for index 0), while
    -- clamping at one would wrongly count the newline of an empty first line when `index` is 0 or 1.
    local prefix_len = math.max(0, index - 1)
    local lnum = select(2, buffer_text:sub(1, prefix_len):gsub("\n", "\n")) + 1
    -- Special case for first line, as there are no newline characters preceding it
    if lnum == 1 then
        return { 1, index }
    end
    -- Remove the text of all preceding lines, plus the newline that terminates the last of them
    local col = index - #table.concat(buffer.get_lines(1, lnum - 1), "\n") - 1
    return { lnum, col }
end

-- Maps a { line, column } buffer position onto an index into the buffer text, where that text is every
-- buffer line joined with "\n".
---@param pos position The position in the buffer.
---@return integer @The index of the character into the buffer.
---@nodiscard
M.pos_to_index = function(pos)
    local lnum, col = pos[1], pos[2]
    if lnum ~= 1 then
        -- Skip every earlier line, plus the newline that ends the last of them
        local preceding = table.concat(buffer.get_lines(1, lnum - 1), "\n")
        return #preceding + 1 + col
    end
    -- Nothing precedes the first line, so its column already is the index
    return col
end

-- Widens a selection so that multi-byte characters at either end are fully contained.
-- The given selection table is updated in place and is also returned.
---@param selection selection The given selection.
---@return selection @The adjusted selection, handling multi-byte characters.
---@nodiscard
M.adjust_selection = function(selection)
    local widened_first = buffer.get_first_byte(selection.first_pos)
    local widened_last = buffer.get_last_byte(selection.last_pos)
    selection.first_pos, selection.last_pos = widened_first, widened_last
    return selection
end

-- Returns a selection in the buffer based on a Lua pattern.
-- The buffer is searched as a single string (lines joined with "\n"). Candidates are picked in this order:
--   1. the best match found by trying start positions from the cursor backwards, if it ends at/after the cursor;
--   2. otherwise the first match starting at/after the cursor;
--   3. otherwise the best match from step 1 (which then ends before the cursor).
---@param find string The Lua pattern to find in the buffer.
---@return selection|nil @The closest selection matching the pattern, if any.
---@nodiscard
M.get_selection = function(find)
    -- Get the current cursor position, buffer contents
    local curpos = buffer.get_curpos()
    local buffer_text = table.concat(buffer.get_lines(1, -1), "\n")
    -- Convert the cursor position into an index into `buffer_text`
    local cursor_index = M.pos_to_index(curpos)
    -- Find the character positions of the pattern in the file (after/on the cursor)
    local a_first, a_last = buffer_text:find(find, cursor_index)
    -- Find the character positions of the pattern in the file (before the cursor)
    local b_first, b_last
    -- Linewise search for the pattern before/on the cursor: walk upwards one line at a time until a match
    -- starting at or before the cursor is found
    for lnum = curpos[1], 1, -1 do
        -- Get the file contents of the lines before line `lnum`
        local cur_text = table.concat(buffer.get_lines(1, lnum - 1), "\n")
        -- Search starting right after that text
        b_first, b_last = buffer_text:find(find, #cur_text + 1)
        if b_first and b_first <= cursor_index then
            break
        end
    end
    -- If no match found, return the after one, if it exists (this evaluates to nil when there is none)
    if not b_first or not b_last then
        return a_first
            and a_last
            and M.adjust_selection({
                first_pos = M.index_to_pos(a_first),
                last_pos = M.index_to_pos(a_last),
            })
    end
    -- Adjust the selection character-wise: try every start index from the cursor back to `b_first`, calling
    -- `find` once per index. If the linewise loop ended on a match starting after the cursor, this range is
    -- empty and the loop below does not run.
    local start_col, end_col = cursor_index, b_first
    b_first, b_last = nil, nil
    for index = start_col, end_col, -1 do
        -- A match found from `index` may start at `index` or anywhere after it (possibly past the cursor)
        local c_first, c_last = buffer_text:find(find, index)
        -- Validate if there is a current match
        if c_last then
            -- If no match yet or the current match is "better", use the current match
            if
                not (b_first and b_last) -- No match yet
                or (b_last == c_last) -- Extending current match
                or (cursor_index < b_first and c_first < b_first) -- Current is closer to cursor, after case
                or (b_last < cursor_index and b_last < c_last) -- Current is closer to cursor, before case
            then
                b_first, b_last = c_first, c_last
            end
        end
    end
    -- If the best character-wise match ends at or after the cursor then return it
    if b_last and b_first and b_last >= cursor_index then
        return M.adjust_selection({
            first_pos = M.index_to_pos(b_first),
            last_pos = M.index_to_pos(b_last),
        })
    end
    -- Else if there's a range found after the cursor, return it
    if a_first and a_last then
        return M.adjust_selection({
            first_pos = M.index_to_pos(a_first),
            last_pos = M.index_to_pos(a_last),
        })
    end
    -- Otherwise return the range found before the cursor, if one exists (falls through to nil otherwise)
    if b_first and b_last then
        return M.adjust_selection({
            first_pos = M.index_to_pos(b_first),
            last_pos = M.index_to_pos(b_last),
        })
    end
end

-- Finds the start and end indices for the given match groups.
-- `pattern` is matched against the text of `selection` (its lines joined with "\n") and must yield four
-- captures: the left delimiter, an empty position capture `()`, the right delimiter, and a second empty
-- position capture. A non-matching pattern yields nil; a malformed one is reported via vim.notify and yields nil.
---@param selection selection The parent selection encompassing the delimiter pair.
---@param pattern string The given Lua pattern to extract match groups from.
---@return selections|nil @The selections for the left and right delimiters.
---@nodiscard
M.get_selections = function(selection, pattern)
    -- Buffer index of the selection's first character, and the selected text itself
    local offset = M.pos_to_index(selection.first_pos)
    local text = table.concat(buffer.get_text(selection), "\n")
    local match_start, _, left_delimiter, first_index, right_delimiter, last_index = text:find(pattern)
    if not match_start then
        return nil
    end
    -- With fewer than four captures the trailing values come back as nil
    if not last_index then
        vim.notify(
            "Four match groups must be present in the Lua pattern, see :h nvim-surround.config.get_selections().",
            vim.log.levels.ERROR
        )
        return nil
    end
    -- Only empty position captures `()` produce numbers
    if not (type(first_index) == "number" and type(last_index) == "number") then
        vim.notify(
            "The second and last capture groups must be empty, see :h nvim-surround.config.get_selections().",
            vim.log.levels.ERROR
        )
        return nil
    end
    ---@cast first_index integer
    ---@cast last_index integer

    -- Byte length of a delimiter capture; a capture that is not a string counts as empty
    local function capture_len(capture)
        if type(capture) == "string" then
            return #capture
        end
        return 0
    end

    -- Builds an adjusted selection between two indices into `text` (1 being its first character)
    local function span(from, to)
        return M.adjust_selection({
            first_pos = M.index_to_pos(offset + from - 1),
            last_pos = M.index_to_pos(offset + to - 1),
        })
    end

    -- Each delimiter ends right before its trailing position capture. An empty delimiter gives a selection
    -- whose last index precedes its first, i.e. the equivalent of an empty selection.
    local left = span(first_index - capture_len(left_delimiter), first_index - 1)
    local right = span(last_index - capture_len(right_delimiter), last_index - 1)
    local selections = { left = left, right = right }

    -- A last position beyond the end of its line is moved to column 0 of the following line
    for _, side in ipairs({ selections.left, selections.right }) do
        local last_pos = side.last_pos
        if last_pos[2] > #buffer.get_line(last_pos[1]) then
            last_pos[1] = last_pos[1] + 1
            last_pos[2] = 0
        end
    end
    return selections
end

return M