Skip to content

Python API Reference

This page documents the main Python modules in BERA Tools using mkdocstrings.

Core Algorithms

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

This file is intended to be hosting algorithms and utility functions/classes for centerline tool.

CenterlineParams

Bases: float, Enum

Parameters for centerline generation.

These parameters are used to control the behavior of centerline generation and should be adjusted based on the specific requirements of the application.

Source code in beratools/core/algo_centerline.py
37
38
39
40
41
42
43
44
45
46
47
48
49
class CenterlineParams(float, enum.Enum):
    """
    Parameters for centerline generation.

    These parameters are used to control the behavior of centerline generation
    and should be adjusted based on the specific requirements of the application.
    """

    BUFFER_CLIP = 5.0
    SEGMENTIZE_LENGTH = 1.0
    SIMPLIFY_LENGTH = 0.5
    SMOOTH_SIGMA = 0.8
    CLEANUP_POLYGON_BY_AREA = 1.0

CenterlineStatus

Bases: IntEnum

Status of centerline generation.

This enum is used to indicate the status of centerline generation. It can be used to track the success or failure of the centerline generation process.

Source code in beratools/core/algo_centerline.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@enum.unique
class CenterlineStatus(enum.IntEnum):
    """
    Outcome codes for centerline generation.

    Used to track whether the centerline generation process succeeded,
    failed, or had to be regenerated (and how that regeneration went).
    """

    SUCCESS = 1  # centerline generated and validated
    FAILED = 2  # no usable centerline produced
    REGENERATE_SUCCESS = 3  # regenerated after the first attempt was invalid
    REGENERATE_FAILED = 4  # regeneration was attempted but failed

SeedLine

Class to store seed line and least cost path.

Source code in beratools/core/algo_centerline.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
class SeedLine:
    """
    Class to store seed line and least cost path.

    Holds one input (seed) line and the raster it runs across; compute()
    derives the least cost path, the corridor polygon and the centerline.
    """

    def __init__(self, line_gdf, ras_file, proc_segments, line_radius):
        # Single-row GeoDataFrame carrying the seed line geometry.
        self.line = line_gdf
        # Raster source the line is processed against.
        self.raster = ras_file
        # NOTE(review): proc_segments is accepted but never stored or used
        # by this class — confirm whether it is still needed by callers.
        # Radius used when clipping the raster around the line.
        self.line_radius = line_radius
        self.lc_path = None  # least cost path; becomes a GeoDataFrame after compute()
        self.centerline = None  # centerline GeoDataFrame, set by compute()
        self.corridor_poly_gpd = None  # corridor polygon GeoDataFrame, set by compute()

    def compute(self):
        """
        Derive the least cost path, corridor polygon and centerline.

        Results are stored on self (lc_path, centerline,
        corridor_poly_gpd) and "cl_status" is written into self.line.
        """
        line = self.line.geometry[0]
        line_radius = self.line_radius
        in_raster = self.raster
        seed_line = line  # LineString
        # NOTE(review): failure paths return this tuple while the success
        # path returns None implicitly — confirm callers only read attributes.
        default_return = (seed_line, seed_line, None)

        # Clip the raster around the seed line and build a cost surface.
        ras_clip, out_meta = bt_common.clip_raster(in_raster, seed_line, line_radius)
        cost_clip, _ = algo_cost.cost_raster(ras_clip, out_meta)

        lc_path = line
        try:
            # Two interchangeable least-cost-path backends, selected by flag.
            if bt_const.CenterlineFlags.USE_SKIMAGE_GRAPH:
                lc_path = bt_dijkstra.find_least_cost_path_skimage(cost_clip, out_meta, seed_line)
            else:
                lc_path = bt_dijkstra.find_least_cost_path(cost_clip, out_meta, seed_line)
        except Exception as e:
            print(e)
            return default_return

        if lc_path:
            lc_path_coords = lc_path.coords
        else:
            lc_path_coords = []

        self.lc_path = lc_path

        # search for centerline
        if len(lc_path_coords) < 2:
            print("No least cost path detected, use input line.")
            self.line["cl_status"] = CenterlineStatus.FAILED.value
            return default_return

        # get corridor raster
        # Re-clip around the least cost path with a slightly tighter radius.
        lc_path = sh_geom.LineString(lc_path_coords)
        ras_clip, out_meta = bt_common.clip_raster(in_raster, lc_path, line_radius * 0.9)
        cost_clip, _ = algo_cost.cost_raster(ras_clip, out_meta)

        # Affine transform and pixel sizes of the clipped raster.
        out_transform = out_meta["transform"]
        transformer = rasterio.transform.AffineTransformer(out_transform)
        cell_size = (out_transform[0], -out_transform[4])

        # Convert path endpoints from world coordinates to raster row/col.
        x1, y1 = lc_path_coords[0]
        x2, y2 = lc_path_coords[-1]
        source = [transformer.rowcol(x1, y1)]
        destination = [transformer.rowcol(x2, y2)]
        corridor_thresh_cl = algo_common.corridor_raster(
            cost_clip,
            out_meta,
            source,
            destination,
            cell_size,
            bt_const.FP_CORRIDOR_THRESHOLD,
        )

        # find contiguous corridor polygon and extract centerline
        df = gpd.GeoDataFrame(geometry=[seed_line], crs=out_meta["crs"])
        corridor_poly_gpd = find_corridor_polygon(corridor_thresh_cl, out_transform, df)
        center_line, status = find_centerline(corridor_poly_gpd.geometry.iloc[0], lc_path)
        self.line["cl_status"] = status.value

        # Store results as single-row GeoDataFrames that keep the input
        # line's attribute columns.
        self.lc_path = self.line.copy()
        self.lc_path.geometry = [lc_path]

        self.centerline = self.line.copy()
        self.centerline.geometry = [center_line]

        self.corridor_poly_gpd = corridor_poly_gpd

centerline_is_valid(centerline, input_line)

Check if centerline is valid.

Parameters:

Name Type Description Default
centerline _type_

description

required
input_line LineString

Seed line or least cost path.

required

Returns:

Name Type Description
bool

True if line is valid

Source code in beratools/core/algo_centerline.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def centerline_is_valid(centerline, input_line):
    """
    Validate a generated centerline against its input line.

    A centerline is rejected when it is falsy, shorter than half of the
    input line, or does not reach either endpoint of the input line
    within bt_const.BT_EPSILON.

    Args:
        centerline: Geometry to validate (may be None or empty).
        input_line (sh_geom.LineString): Seed line or least cost path.
        Only two end points are used.

    Returns:
        bool: True if line is valid

    """
    if not centerline:
        return False

    start_pt = sh_geom.Point(input_line.coords[0])
    end_pt = sh_geom.Point(input_line.coords[-1])

    # Reject a centerline shorter than half the input line.
    if centerline.length < input_line.length / 2:
        return False

    # The centerline must touch both endpoints of the input line.
    if centerline.distance(start_pt) > bt_const.BT_EPSILON:
        return False
    if centerline.distance(end_pt) > bt_const.BT_EPSILON:
        return False

    return True

find_centerline(poly, input_line)

Find centerline from polygon and input line.

Parameters:

Name Type Description Default
poly

sh_geom.Polygon

required
input_line LineString

Least cost path or seed line

required

Returns: centerline (sh_geom.LineString): Centerline. status (CenterlineStatus): Status of centerline generation.

Source code in beratools/core/algo_centerline.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def find_centerline(poly, input_line):
    """
    Find centerline from polygon and input line.

    Args:
        poly (sh_geom.Polygon): Corridor polygon to extract the centerline from.
        input_line (sh_geom.LineString): Least cost path or seed line.

    Returns:
        centerline (sh_geom.LineString): Centerline.
        status (CenterlineStatus): Status of centerline generation.

    """
    default_return = input_line, CenterlineStatus.FAILED
    if not poly:
        print("find_centerline: No polygon found")
        return default_return

    # Densify the polygon boundary before extraction.
    poly = shapely.segmentize(poly, max_segment_length=CenterlineParams.SEGMENTIZE_LENGTH)

    # buffer to reduce MultiPolygons
    poly = poly.buffer(bt_const.SMALL_BUFFER)
    if type(poly) is sh_geom.MultiPolygon:
        print("sh_geom.MultiPolygon encountered, skip.")
        return default_return

    exterior_pts = list(poly.exterior.coords)

    if bt_const.CenterlineFlags.DELETE_HOLES:
        poly = sh_geom.Polygon(exterior_pts)
    if bt_const.CenterlineFlags.SIMPLIFY_POLYGON:
        poly = poly.simplify(CenterlineParams.SIMPLIFY_LENGTH)

    # TODO: filter Voronoi vertices by intersecting buffers around the
    # input line's endpoints with the polygon and passing the results as
    # src_geom/dst_geom. The previous implementation computed those
    # intersections and then discarded them; pass None until the
    # filtering is actually wired up.
    src_geom = None
    dst_geom = None

    try:
        centerline = get_centerline(
            poly,
            segmentize_maxlen=1,
            max_points=3000,
            simplification=0.05,
            smooth_sigma=CenterlineParams.SMOOTH_SIGMA,
            max_paths=1,
            src_geom=src_geom,
            dst_geom=dst_geom,
        )
    except Exception as e:
        print(f"find_centerline: {e}")
        return default_return

    if not centerline:
        return default_return

    if type(centerline) is sh_geom.MultiLineString:
        if len(centerline.geoms) > 1:
            print(" Multiple centerline segments detected, no further processing.")
            return centerline, CenterlineStatus.SUCCESS  # TODO: inspect
        elif len(centerline.geoms) == 1:
            centerline = centerline.geoms[0]
        else:
            return default_return

    cl_coords = list(centerline.coords)

    # trim centerline at two ends
    head_buffer = sh_geom.Point(cl_coords[0]).buffer(CenterlineParams.BUFFER_CLIP)
    centerline = centerline.difference(head_buffer)

    end_buffer = sh_geom.Point(cl_coords[-1]).buffer(CenterlineParams.BUFFER_CLIP)
    centerline = centerline.difference(end_buffer)

    # No centerline detected, use input line instead.
    if not centerline:
        return default_return
    try:
        # Empty centerline detected, use input line instead.
        if centerline.is_empty:
            return default_return
    except Exception as e:
        print(f"find_centerline: {e}")

    centerline = snap_end_to_end(centerline, input_line)

    # Check centerline. If NOT valid, regenerate by splitting the polygon
    # into two halves and finding a centerline for each half.
    if not centerline_is_valid(centerline, input_line):
        try:
            print("Regenerating line ...")
            centerline = regenerate_centerline(poly, input_line)
            if centerline is None:
                # regenerate_centerline signals failure by returning None;
                # previously this was still reported as REGENERATE_SUCCESS.
                return input_line, CenterlineStatus.REGENERATE_FAILED
            return centerline, CenterlineStatus.REGENERATE_SUCCESS
        except Exception as e:
            print(f"find_centerline: {e}")
            return input_line, CenterlineStatus.REGENERATE_FAILED

    return centerline, CenterlineStatus.SUCCESS

process_single_centerline(row_and_path)

Find centerline.

Args: row_and_path (list): a two-item list — the first item is a GeoPandas row (with the corridor polygon geometry), the second is the input line (least cost path).

Returns: row: GeoPandas row with centerline

Source code in beratools/core/algo_centerline.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
def process_single_centerline(row_and_path):
    """
    Compute a centerline for one (row, least-cost-path) pair.

    Args:
        row_and_path (list): Two items — the first is a GeoPandas row
        (with a polygon geometry), the second is the input line
        (least cost path).

    Returns:
        row: GeoPandas row with centerline stored under "centerline".

    """
    row, lc_path = row_and_path[0], row_and_path[1]

    corridor_poly = row.geometry.iloc[0]
    centerline, _status = find_centerline(corridor_poly, lc_path)
    row["centerline"] = centerline
    return row

regenerate_centerline(poly, input_line)

Regenerates centerline when initial poly is not valid.

Parameters:

Name Type Description Default
input_line LineString

Seed line or least cost path.

required

Returns:

Type Description

sh_geom.MultiLineString

Source code in beratools/core/algo_centerline.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def regenerate_centerline(poly, input_line):
    """
    Regenerate centerline when the initial result is not valid.

    Splits the input line in half, cuts the polygon with a perpendicular
    line at the midpoint, finds a centerline for each half and merges them.

    Args:
        poly (sh_geom.Polygon): Corridor polygon to split and re-process.
        input_line (sh_geom.LineString): Seed line or least cost path.
        Only two end points will be used.

    Returns:
        sh_geom.MultiLineString (merged), or None on failure.

    """
    # Split the input line into two halves at its midpoint.
    line_1 = sh_ops.substring(input_line, start_dist=0.0, end_dist=input_line.length / 2)
    line_2 = sh_ops.substring(input_line, start_dist=input_line.length / 2, end_dist=input_line.length)

    # Endpoints and midpoint used to orient the perpendicular cut line.
    pts = shapely.force_2d(
        [
            sh_geom.Point(list(input_line.coords)[0]),
            sh_geom.Point(list(line_1.coords)[-1]),
            sh_geom.Point(list(input_line.coords)[-1]),
        ]
    )
    perp = algo_common.generate_perpendicular_line_precise(pts)

    # sh_geom.MultiPolygon is rare, but need to be dealt with
    # remove polygon of area less than CenterlineParams.CLEANUP_POLYGON_BY_AREA
    poly = poly.buffer(bt_const.SMALL_BUFFER)
    if type(poly) is sh_geom.MultiPolygon:
        poly_geoms = list(poly.geoms)
        poly_valid = [True] * len(poly_geoms)
        for i, item in enumerate(poly_geoms):
            if item.area < CenterlineParams.CLEANUP_POLYGON_BY_AREA:
                poly_valid[i] = False

        poly_geoms = list(compress(poly_geoms, poly_valid))
        if not poly_geoms:
            # Every part was filtered out; indexing below would raise
            # IndexError in the original code.
            print("regenerate_centerline: Multi or none polygon found, pass.")
            return None
        if len(poly_geoms) != 1:  # still multi polygon
            print("regenerate_centerline: Multi or none polygon found, pass.")

        poly = sh_geom.Polygon(poly_geoms[0])

    # Split the (hole-free) exterior polygon with the perpendicular line.
    poly_exterior = sh_geom.Polygon(poly.buffer(bt_const.SMALL_BUFFER).exterior)
    poly_split = sh_ops.split(poly_exterior, perp)

    if len(poly_split.geoms) < 2:
        print("regenerate_centerline: polygon sh_ops.split failed, pass.")
        return None

    poly_1 = poly_split.geoms[0]
    poly_2 = poly_split.geoms[1]

    # find polygon and line pairs
    pair_line_1 = line_1
    pair_line_2 = line_2
    if not poly_1.intersects(line_1):
        pair_line_1 = line_2
        pair_line_2 = line_1
    elif poly_1.intersection(line_1).length < line_1.length / 3:
        # line_1 only grazes poly_1 — it belongs to the other half.
        pair_line_1 = line_2
        pair_line_2 = line_1

    center_line_1 = find_centerline(poly_1, pair_line_1)
    center_line_2 = find_centerline(poly_2, pair_line_2)

    # find_centerline returns (geometry, status); keep the geometries only.
    center_line_1 = center_line_1[0]
    center_line_2 = center_line_2[0]

    if not center_line_1 or not center_line_2:
        print("Regenerate line: centerline is None")
        return None

    try:
        if center_line_1.is_empty or center_line_2.is_empty:
            print("Regenerate line: centerline is empty")
            return None
    except Exception as e:
        print(f"regenerate_centerline: {e}")

    print("Centerline is regenerated.")
    return sh_ops.linemerge(sh_geom.MultiLineString([center_line_1, center_line_2]))

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to provide common algorithms and utility functions/classes.

clean_geometries(gdf)

Remove rows with invalid, None, or empty geometries from the GeoDataFrame.

Parameters:

Name Type Description Default
gdf GeoDataFrame

The GeoDataFrame to clean.

required

Returns:

Name Type Description
GeoDataFrame

The cleaned GeoDataFrame with valid, non-null,

and non-empty geometries.

Source code in beratools/core/algo_common.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def clean_geometries(gdf):
    """
    Remove rows with invalid, None, or empty geometries from the GeoDataFrame.

    Args:
        gdf (GeoDataFrame): The GeoDataFrame to clean.

    Returns:
        GeoDataFrame: The cleaned GeoDataFrame with valid, non-null,
        and non-empty geometries.

    """
    # Filter in three passes: validity, missing geometries, empty geometries.
    gdf = gdf.loc[gdf.geometry.is_valid]
    gdf = gdf.loc[~gdf.geometry.isna()]
    gdf = gdf.loc[gdf.geometry.apply(lambda geom: not geom.is_empty)]
    return gdf

clean_line_geometries(line_gdf)

Clean line geometries in the GeoDataFrame.

Source code in beratools/core/algo_common.py
120
121
122
123
124
125
126
127
128
129
130
def clean_line_geometries(line_gdf):
    """Drop missing, empty and near-zero-length line geometries."""
    # Nothing to clean for a missing or empty frame.
    if line_gdf is None or line_gdf.empty:
        return line_gdf

    keep = ~line_gdf.geometry.isna() & ~line_gdf.geometry.is_empty
    line_gdf = line_gdf[keep]
    return line_gdf[line_gdf.geometry.length > bt_const.SMALL_BUFFER]

corridor_raster(raster_clip, out_meta, source, destination, cell_size, corridor_threshold)

Calculate corridor raster.

Parameters:

Name Type Description Default
raster_clip raster
required
out_meta

raster file meta

required
source list of point tuple(s)

start point in row/col

required
destination list of point tuple(s)

end point in row/col

required
cell_size tuple

(cell_size_x, cell_size_y)

required

Returns: corridor raster

Source code in beratools/core/algo_common.py
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def corridor_raster(raster_clip, out_meta, source, destination, cell_size, corridor_threshold):
    """
    Calculate corridor raster.

    Accumulates least-cost distances from both endpoints, sums them,
    normalizes by the minimum of that sum, and thresholds the result
    into a binary (1.0/0.0) raster.

    Args:
        raster_clip (raster): cost raster (2D, or 3D with a leading band axis)
        out_meta : raster file meta (not referenced in this body)
        source (list of point tuple(s)): start point in row/col
        destination (list of point tuple(s)): end point in row/col
        cell_size (tuple): (cell_size_x, cell_size_y)
        corridor_threshold (double): cutoff applied to the normalized sum

    Returns:
    corridor raster (masked array of 1.0/0.0), or None on error

    """
    try:
        # change all nan to BT_NODATA_COST for workaround
        if len(raster_clip.shape) > 2:
            raster_clip = np.squeeze(raster_clip, axis=0)

        # In-place NaN removal (mutates raster_clip).
        algo_cost.remove_nan_from_array_refactor(raster_clip)

        # generate the cost raster to source point
        mcp_source = sk_graph.MCP_Geometric(raster_clip, sampling=cell_size)
        source_cost_acc = mcp_source.find_costs(source)[0]
        del mcp_source  # release before building the second MCP

        # generate the cost raster to destination point
        mcp_dest = sk_graph.MCP_Geometric(raster_clip, sampling=cell_size)
        dest_cost_acc = mcp_dest.find_costs(destination)[0]

        # Generate corridor: total accumulated cost through each cell.
        corridor = source_cost_acc + dest_cost_acc
        corridor = np.ma.masked_invalid(corridor)

        # Calculate minimum value of corridor raster
        if np.ma.min(corridor) is not None:
            corr_min = float(np.ma.min(corridor))
        else:
            corr_min = 0.5

        # normalize corridor raster by deducting corr_min
        corridor_norm = corridor - corr_min
        # NOTE(review): cells at or ABOVE the threshold become 1.0 —
        # confirm downstream expects 1.0 to mean "outside the corridor".
        corridor_thresh_cl = np.ma.where(corridor_norm >= corridor_threshold, 1.0, 0.0)

    except Exception as e:
        print(e)
        print("corridor_raster: Exception occurred.")
        return None

    return corridor_thresh_cl

generate_perpendicular_line_precise(points, offset=20)

Generate a perpendicular line to the input line at the given point.

Parameters:

Name Type Description Default
points list[Point]

The points where to generate the perpendicular lines.

required
offset float

The length of the perpendicular line.

20

Returns:

Type Description

shapely.geometry.LineString: The generated perpendicular line.

Source code in beratools/core/algo_common.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
def generate_perpendicular_line_precise(points, offset=20):
    """
    Generate a perpendicular line to the input line at the given point.

    With two points the perpendicular is taken at the second point,
    normal to the segment direction. With three points the perpendicular
    is oriented halfway between the directions to the first and last
    points, anchored at the middle point.

    Args:
        points (list[Point]): Two or three points; the perpendicular is
        anchored at points[1].
        offset (float): The length of the perpendicular line.

    Returns:
        shapely.geometry.LineString: The generated perpendicular line,
        or None when the point count is not 2 or 3 (or construction fails).

    """
    n_pts = len(points)
    if n_pts not in (2, 3):
        return None

    center = points[1]

    if n_pts == 2:
        head, tail = points[0], points[1]
        dx = head.x - tail.x
        dy = head.y - tail.y

        # Segment angle; vertical segments get pi/2 directly.
        if math.isclose(dx, 0.0):
            seg_angle = math.pi / 2
        else:
            seg_angle = math.atan(dy / dx)

        # Build a horizontal segment centred on the anchor, then rotate it
        # so it is normal to the input segment.
        horiz = sh_geom.LineString(
            [[center.x + offset / 2.0, center.y], [center.x - offset / 2.0, center.y]]
        )
        return sh_aff.rotate(horiz, seg_angle + math.pi / 2.0, origin=center, use_radians=True)

    # Three points: orient by the half-angle between the two directions.
    head, tail = points[0], points[2]
    angle_1 = _line_angle(center, head)
    angle_2 = _line_angle(center, tail)
    half_diff = (angle_2 - angle_1) / 2.0

    head_new = sh_geom.Point(
        center.x + offset / 2.0 * math.cos(angle_1),
        center.y + offset / 2.0 * math.sin(angle_1),
    )
    if head.has_z:
        head_new = shapely.force_3d(head_new)

    perp_line = None
    try:
        half_seg = sh_geom.LineString([center, head_new])
        half_seg = sh_aff.rotate(half_seg, half_diff, origin=center, use_radians=True)
        opposite_seg = sh_aff.rotate(half_seg, math.pi, origin=center, use_radians=True)
        perp_line = sh_geom.LineString(
            [list(half_seg.coords)[1], list(opposite_seg.coords)[1]]
        )
    except Exception as e:
        print(e)

    return perp_line

get_angle(line, vertex_index)

Calculate the angle of the first or last segment.

TODO: use np.arctan2 instead of np.arctan

Args: line: LineString end_index: 0 or -1 of the line vertices. Consider the multipart.

Source code in beratools/core/algo_common.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def get_angle(line, vertex_index):
    """
    Calculate the angle of the first or last segment.

    # TODO: use np.arctan2 instead of np.arctan

    Args:
        line: LineString
        vertex_index (int): 0 for the first segment, -1 for the last
        segment of the line vertices. Consider the multipart.

    Returns:
        float: segment angle in radians, regulated to [-pi/2, 3*pi/2]

    Raises:
        ValueError: if vertex_index is not 0 or -1

    """
    pts = line_coord_list(line)

    if vertex_index == 0:
        pt_1 = pts[0]
        pt_2 = pts[1]
    elif vertex_index == -1:
        pt_1 = pts[-1]
        pt_2 = pts[-2]
    else:
        # Previously an unsupported index fell through to a NameError.
        raise ValueError(f"vertex_index must be 0 or -1, got {vertex_index}")

    delta_x = pt_2.x - pt_1.x
    delta_y = pt_2.y - pt_1.y
    if np.isclose(pt_1.x, pt_2.x):
        # Vertical segment: +pi/2 upward, -pi/2 downward.
        angle = np.pi / 2
        if delta_y > 0:
            angle = np.pi / 2
        elif delta_y < 0:
            angle = -np.pi / 2
    else:
        angle = np.arctan(delta_y / delta_x)

        # arctan is in range [-pi/2, pi/2]; shift by pi when delta_x < 0
        # (second or third quadrant) to regulate angles to [-pi/2, 3*pi/2]
        if delta_x < 0:
            angle += np.pi

    return angle

has_multilinestring(gdf)

Check if any geometry is a MultiLineString.

Source code in beratools/core/algo_common.py
94
95
96
97
98
def has_multilinestring(gdf):
    """Check if any geometry in the GeoDataFrame is a MultiLineString."""
    # Scan every geometry; stops at the first MultiLineString found.
    return any(isinstance(geom, sh_geom.MultiLineString) for geom in gdf.geometry)

intersection_of_lines(line_1, line_2)

Only LINESTRING is dealt with for now.

Args: line_1: first line geometry. line_2: second line geometry.

Returns: sh_geom.Point: intersection point

Source code in beratools/core/algo_common.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def intersection_of_lines(line_1, line_2):
    """
    Intersect two lines; only LINESTRING is dealt with for now.

    Args:
        line_1: first line geometry (may be None)
        line_2: second line geometry (may be None)

    Returns:
        sh_geom.Point: intersection point (centroid when the intersection
        is a collection or line), the raw intersection otherwise, or None

    """
    # Either input missing: no intersection to compute.
    if not (line_1 and line_2):
        return None

    inter = line_1.intersection(line_2)

    # TODO: intersection may return GeometryCollection, LineString or
    # MultiLineString — collapse those to a representative point.
    if inter and type(inter) in (
        sh_geom.GeometryCollection,
        sh_geom.LineString,
        sh_geom.MultiLineString,
    ):
        return inter.centroid

    return inter

prepare_lines_gdf(file_path, layer=None, proc_segments=True)

Split lines at vertices or return original rows.

It handles for MultiLineString.

Source code in beratools/core/algo_common.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def prepare_lines_gdf(file_path, layer=None, proc_segments=True):
    """
    Split lines at vertices or return original rows.

    It handles MultiLineString geometries by exploding them first.

    Args:
        file_path: path to the geospatial file to read
        layer (str, optional): layer name for multi-layer sources
        proc_segments (bool): if True, split each line into one segment
        per consecutive vertex pair; otherwise keep lines whole

    Returns:
        list[GeoDataFrame]: one single-row GeoDataFrame per output line

    """
    # Check if there are any MultiLineString geometries
    gdf = read_geospatial_file(file_path, layer=layer)

    # Explode MultiLineStrings into individual LineStrings
    if has_multilinestring(gdf):
        gdf = gdf.explode(index_parts=False)

    # Non-geometry column names are the same for every row; compute once
    # instead of re-deriving them inside the loops.
    attr_cols = [col for col in gdf.columns if col != "geometry"]

    split_gdf_list = []

    for row in gdf.itertuples(index=False):  # Use itertuples to iterate
        line = row.geometry  # Access geometry directly via the named tuple

        # Attribute values are invariant across a row's segments: build the
        # dict once per row (previously rebuilt for every segment).
        attributes = {col: getattr(row, col) for col in attr_cols}

        # If proc_segments is True, split the line at vertices
        if proc_segments:
            coords = list(line.coords)  # Extract the list of coordinates (vertices)

            # For each LineString, split the line into segments by the vertices
            for i in range(len(coords) - 1):
                segment = sh_geom.LineString([coords[i], coords[i + 1]])
                single_row_gdf = gpd.GeoDataFrame([attributes], geometry=[segment], crs=gdf.crs)
                split_gdf_list.append(single_row_gdf)

        else:
            # If not proc_segments, add the original row as a single-row GeoDataFrame
            single_row_gdf = gpd.GeoDataFrame([attributes], geometry=[line], crs=gdf.crs)
            split_gdf_list.append(single_row_gdf)

    return split_gdf_list

process_single_item(cls_obj)

Process a class object for universal multiprocessing.

Parameters:

Name Type Description Default
cls_obj

Class object to be processed

required

Returns:

Name Type Description
cls_obj

Class object after processing

Source code in beratools/core/algo_common.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def process_single_item(cls_obj):
    """
    Process a class object for universal multiprocessing.

    Args:
        cls_obj: Class object exposing a compute() method

    Returns:
        cls_obj: Class object after processing, or None when compute() raised

    """
    try:
        cls_obj.compute()
    except Exception as e:
        import traceback

        # Report the failure with a full traceback, then signal it with None.
        print(f"❌ Exception during compute() for object: {e}")
        traceback.print_exc()
        return None

    return cls_obj

read_geospatial_file(file_path, layer=None)

Read a geospatial file, clean the geometries and return a GeoDataFrame.

Parameters:

Name Type Description Default
file_path str

The path to the geospatial file (e.g., .shp, .gpkg).

required
layer str

The specific layer to read if the file is

None

Returns:

Name Type Description
GeoDataFrame

The cleaned GeoDataFrame containing the data from the file

with valid geometries only.

None

If there is an error reading the file or layer.

Source code in beratools/core/algo_common.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def read_geospatial_file(file_path, layer=None):
    """
    Read a geospatial file, clean the geometries and return a GeoDataFrame.

    Args:
        file_path (str): The path to the geospatial file (e.g., .shp, .gpkg).
        layer (str, optional): The specific layer to read if the file is
        multi-layered (e.g., GeoPackage).

    Returns:
        GeoDataFrame: The cleaned GeoDataFrame with valid geometries only,
        plus a temporary "BT_UID" column.
        None: If there is an error reading the file or layer.

    """
    try:
        # Only pass layer= when one was requested.
        read_kwargs = {} if layer is None else {"layer": layer}
        gdf = gpd.read_file(file_path, **read_kwargs)

        # Drop invalid/missing/empty geometries, then tag rows with a UID.
        gdf = clean_geometries(gdf)
        gdf["BT_UID"] = range(len(gdf))  # assign temporary UID
        return gdf

    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None

save_raster_to_file(in_raster_mem, in_meta, out_raster_file)

Save raster matrix in memory to file.

Parameters:

Name Type Description Default
in_raster_mem

numpy raster

required
in_meta

input meta

required
out_raster_file

output raster file

required
Source code in beratools/core/algo_common.py
332
333
334
335
336
337
338
339
340
341
342
343
def save_raster_to_file(in_raster_mem, in_meta, out_raster_file):
    """
    Save raster matrix in memory to file.

    Args:
        in_raster_mem: numpy raster (written as band 1)
        in_meta: input meta (rasterio profile for the output)
        out_raster_file: output raster file path

    """
    # Write the single-band array; the context manager closes the dataset.
    with rasterio.open(out_raster_file, "w", **in_meta) as out_dataset:
        out_dataset.write(in_raster_mem, indexes=1)

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

This file hosts cost raster related functions.

circle_kernel_refactor(size, radius)

Create a circular kernel using Scipy.

Args: size : kernel size radius : radius of the circle

Returns: kernel (ndarray): A circular kernel.

Examples: kernel_scipy = create_circle_kernel_scipy(17, 8) will replicate xarray-spatial kernel cell_x = 0.3 cell_y = 0.3 tree_radius = 2.5 convolution.circle_kernel(cell_x, cell_y, tree_radius)

Source code in beratools/core/algo_cost.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def circle_kernel_refactor(size, radius):
    """
    Create a circular kernel using NumPy.

    Args:
        size : kernel side length (kernel is size x size)
        radius : radius of the circle, in cells

    Returns:
        kernel (ndarray): A circular kernel of 0.0/1.0 values, 1.0 for
        cells within `radius` of the kernel centre.

    Examples:
    kernel_scipy = create_circle_kernel_scipy(17, 8)
    will replicate xarray-spatial kernel
    cell_x = 0.3
    cell_y = 0.3
    tree_radius = 2.5
    convolution.circle_kernel(cell_x, cell_y, tree_radius)

    """
    # Open row/column index grids (broadcast to size x size).
    rows, cols = np.ogrid[:size, :size]

    # Geometric centre of the kernel (same value for both axes).
    mid = (size - 1) / 2

    # Euclidean distance of every cell from the centre.
    dist_from_center = np.sqrt((cols - mid) ** 2 + (rows - mid) ** 2)

    # Cells on or inside the circle become 1.0, the rest 0.0.
    return (dist_from_center <= radius).astype(float)

cost_norm_dist_transform(canopy_ndarray, max_line_dist, sampling)

Compute a distance-based cost map based on the proximity of valid data points.

Source code in beratools/core/algo_cost.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def cost_norm_dist_transform(canopy_ndarray, max_line_dist, sampling):
    """
    Compute a distance-based cost map based on the proximity of valid data points.

    Args:
        canopy_ndarray: masked array; masked cells are treated as nodata.
        max_line_dist: distance at which the cost drops to zero.
        sampling: per-axis cell size for the distance transform.

    Returns:
        ndarray: cost values in [0, 1], NaN on originally masked cells.

    """
    limit = float(max_line_dist)

    # Masked cells become NaN so they act as zero-distance seeds below.
    filled = canopy_ndarray.filled(np.nan)
    valid = np.logical_not(np.isnan(filled))

    # Distance from each valid cell to the nearest invalid (NaN) cell.
    dist = scipy.ndimage.distance_transform_edt(valid, sampling=sampling)

    # Restore NaN on the originally masked cells.
    dist[canopy_ndarray.mask] = np.nan

    # Linear falloff, clipped at zero and normalized by the cutoff distance.
    return np.maximum(limit - dist, 0.0) / limit

cost_raster(in_raster, meta, tree_radius=2.5, canopy_ht_threshold=2.5, max_line_dist=2.5, canopy_avoid=0.4, cost_raster_exponent=1.5)

General version of cost_raster.

To be merged later: variables and consistent nodata solution

Source code in beratools/core/algo_cost.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def cost_raster(
    in_raster,
    meta,
    tree_radius=2.5,
    canopy_ht_threshold=2.5,
    max_line_dist=2.5,
    canopy_avoid=0.4,
    cost_raster_exponent=1.5,
):
    """
    General version of cost_raster.

    To be merged later: variables and consistent nodata solution

    Args:
        in_raster: CHM array, optionally with a leading band axis.
        meta: raster metadata dict; only the affine "transform" is read.
        tree_radius: tree search radius in map units.
        canopy_ht_threshold: height threshold separating canopy/non-canopy.
        max_line_dist: distance cutoff for the smoothness cost.
        canopy_avoid: canopy avoidance weight, clamped into [0, 1].
        cost_raster_exponent: exponent applied when combining cost layers.

    Returns:
        tuple: (cost array, canopy array), both NaN where input is nodata.

    """
    # Drop the band axis when the raster arrives as (1, rows, cols).
    if in_raster.ndim > 2:
        in_raster = np.squeeze(in_raster, axis=0)

    # Pixel sizes from the affine transform (y scale is negative there).
    cell_x = meta["transform"][0]
    cell_y = -meta["transform"][4]

    # Clamp the avoidance weight into [0, 1].
    avoid_weight = min(1, max(0, canopy_avoid))

    # Circular focal kernel sized from the tree radius in cells.
    radius_cells = int(tree_radius / cell_x)
    kernel = circle_kernel_refactor(2 * radius_cells + 1, radius_cells)

    canopy = dyn_np_cc_map(in_raster, canopy_ht_threshold)
    cc_std, cc_mean = cost_focal_stats(canopy, kernel)
    cc_smooth = cost_norm_dist_transform(canopy, max_line_dist, [cell_x, cell_y])

    cost_clip = dyn_np_cost_raster_refactor(
        canopy, cc_mean, cc_std, cc_smooth, avoid_weight, cost_raster_exponent
    )

    # TODO use nan or BT_DATA?
    nodata_cells = in_raster == bt_const.BT_NODATA
    cost_clip[nodata_cells] = np.nan
    canopy[nodata_cells] = np.nan

    return cost_clip, canopy

dyn_np_cc_map(in_chm, canopy_ht_threshold)

Create a new canopy raster.

MaskedArray based on the threshold comparison of in_chm (canopy height model) with canopy_ht_threshold. It assigns 1.0 where the condition is True (canopy) and 0.0 where the condition is False (non-canopy).

Source code in beratools/core/algo_cost.py
67
68
69
70
71
72
73
74
75
76
77
def dyn_np_cc_map(in_chm, canopy_ht_threshold):
    """
    Create a new canopy raster from a canopy height model.

    Cells of in_chm at or above canopy_ht_threshold become 1.0 (canopy) and
    all other cells become 0.0 (non-canopy). The result is produced with
    np.ma.where, so a MaskedArray is returned.

    """
    is_canopy = in_chm >= canopy_ht_threshold
    canopy_ndarray = np.ma.where(is_canopy, 1.0, 0.0)
    return canopy_ndarray.astype(float)

Least Cost Path Algorithm.

This algorithm is adapted from the QGIS plugin: Find the least cost path with given cost raster and points Original author: FlowMap Group@SESS.PKU Source code repository: https://github.com/Gooong/LeastCostPath

Copyright (C) 2023 by AppliedGRG Author: Richard Zeng Date: 2023-03-01

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version.

MinCostPathHelper

Helper class for the cost matrix.

Source code in beratools/core/algo_dijkstra.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class MinCostPathHelper:
    """Helper class for the cost matrix."""

    @staticmethod
    def _point_to_row_col(point_xy, ras_transform):
        # NOTE(review): rasterio-style rowcol() returns (row, col), but the
        # result is unpacked here as (col, row) and then returned swapped.
        # Confirm whether callers rely on this ordering.
        col, row = ras_transform.rowcol(point_xy.x(), point_xy.y())

        return row, col

    @staticmethod
    def _row_col_to_point(row_col, ras_transform):
        # Convert a (row, col) raster index to world (x, y) coordinates.
        x, y = ras_transform.xy(row_col[0], row_col[1])
        return x, y

    @staticmethod
    def create_points_from_path(ras_transform, min_cost_path, start_point, end_point):
        """Convert a path of raster indices to (x, y) points, snapping endpoints."""
        path_points = list(
            map(
                lambda row_col: MinCostPathHelper._row_col_to_point(row_col, ras_transform),
                min_cost_path,
            )
        )
        # Replace the first/last cell-centre points with the exact endpoints.
        path_points[0] = (start_point.x, start_point.y)
        path_points[-1] = (end_point.x, end_point.y)
        return path_points

    @staticmethod
    def create_path_feature_from_points(path_points, attr_vals):
        """Build a (LineString, attributes) tuple from points exposing .x/.y."""
        path_points_raw = [[pt.x, pt.y] for pt in path_points]

        return sh_geom.LineString(path_points_raw), attr_vals

    @staticmethod
    def block2matrix_numpy(block, nodata):
        """Sanitize a raster block in place and report negative cost values.

        Cells <= nodata or NaN are overwritten with a large finite cost
        (9999.0); any remaining negative cell sets the returned flag.
        """
        contains_negative = False
        with np.nditer(block, flags=["refs_ok"], op_flags=["readwrite"]) as it:
            for x in it:
                # TODO: this speeds up a lot, but need further inspection
                # if np.isclose(x, nodata) or np.isnan(x):
                if x <= nodata or np.isnan(x):
                    x[...] = 9999.0
                elif x < 0:
                    contains_negative = True

        return block, contains_negative

    @staticmethod
    def block2matrix(block, nodata):
        """Convert a raster block to nested lists, None where nodata."""
        contains_negative = False
        # NOTE(review): names assume block.shape == (width, height); numpy's
        # shape order is (rows, cols), so these labels may be transposed.
        width, height = block.shape
        # TODO: deal with nodata
        matrix = [
            [
                None
                if np.isclose(block[i][j], nodata) or np.isclose(block[i][j], bt_const.BT_NODATA)
                else block[i][j]
                for j in range(height)
            ]
            for i in range(width)
        ]

        # Flag genuinely negative values (excluding the nodata sentinel).
        for row in matrix:
            for v in row:
                if v is not None:
                    if v < 0 and not np.isclose(v, bt_const.BT_NODATA):
                        contains_negative = True

        return matrix, contains_negative

dijkstra_np(start_tuple, end_tuple, matrix)

Dijkstra's algorithm for finding the shortest path between two nodes in a graph.

Parameters:

Name Type Description Default
start_node list

[row,col] coordinates of the initial node

required
end_node list

[row,col] coordinates of the desired node

required
matrix array 2d

numpy array that contains obstacles as 1s and free space as 0s

required

Returns:

Type Description

list[list]: list of list of nodes that form the shortest path

Source code in beratools/core/algo_dijkstra.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def dijkstra_np(start_tuple, end_tuple, matrix):
    """
    Find the least cost path between two nodes via skimage routing.

    Args:
        start_tuple: tuple whose first element is the [row, col] start node.
        end_tuple: tuple whose first element is the [row, col] end node.
        matrix (ndarray): 2D cost array passed to sk_graph.route_through_array.
            NOTE: mutated in place — the start/end cells are zeroed so the
            path may begin and end there regardless of their original cost.

    Returns:
        list | None: single-element list [(path, costs, end_tuple)] where
        costs is a list of zeros matching the path length, or None when
        routing fails.

    """
    start_node = start_tuple[0]
    end_node = end_tuple[0]

    try:
        # Free the terminal cells so routing can start/end on them.
        matrix[start_node[0], start_node[1]] = 0
        matrix[end_node[0], end_node[1]] = 0

        path, _cost = sk_graph.route_through_array(matrix, start_node, end_node)
    except Exception as e:
        print(f"dijkstra_np: {e}")
        return None

    # Per-node costs are not tracked here; emit zeros of matching length.
    costs = [0.0] * len(path)
    return [(path, costs, end_tuple)]

valid_node(node, size_of_grid)

Check if node is within the grid boundaries.

Source code in beratools/core/algo_dijkstra.py
268
269
270
271
272
273
274
def valid_node(node, size_of_grid):
    """Return True when both node coordinates fall inside the square grid."""
    row, col = node[0], node[1]
    return 0 <= row < size_of_grid and 0 <= col < size_of_grid

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng, Maverick Fong

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to provide main interface for canopy footprint tool. The tool is used to generate the footprint of a line based on relative threshold.

BufferRing

Buffer ring class.

Source code in beratools/core/algo_footprint_rel.py
110
111
112
113
114
115
116
117
class BufferRing:
    """One single-sided buffer ring around a line plus its percentile stats."""

    def __init__(self, ring_poly, side):
        # Ring geometry and which side of the line it sits on.
        self.geometry = ring_poly
        self.side = side
        # Placeholder statistics, overwritten once the ring is analysed.
        self.percentile = 0.5
        self.Dyn_Canopy_Threshold = 0.05

FootprintCanopy

Relative canopy footprint class.

Source code in beratools/core/algo_footprint_rel.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class FootprintCanopy:
    """Relative canopy footprint class."""

    def __init__(self, in_geom, in_chm, in_layer=None):
        """Read the input vector file and wrap each row in a LineInfo.

        Args:
            in_geom: path to the vector file containing the lines.
            in_chm: canopy height model raster (passed through to LineInfo).
            in_layer: optional layer name within in_geom.
        """
        data = gpd.read_file(in_geom, layer=in_layer)
        self.lines = []

        for idx in data.index:
            line = LineInfo(data.iloc[[idx]], in_chm)
            self.lines.append(line)

    def compute(self, processes, parallel_mode=bt_const.ParallelMode.MULTIPROCESSING):
        """Process all lines (optionally in parallel) and aggregate results.

        Populates self.footprints and self.lines_percentile (either may end
        up None when no line produced a valid result).
        """
        result = bt_base.execute_multiprocessing(
            algo_common.process_single_item,
            self.lines,
            "Canopy Footprint",
            processes,
            1,
            parallel_mode,
        )

        fp = []
        percentile = []
        try:
            # Collect footprints, skipping lines whose processing failed.
            for item in result:
                if item.footprint is not None:
                    fp.append(item.footprint)
                else:
                    print("Footprint is None for one of the lines.")
                    continue  # Skip failed line

            if fp:
                self.footprints = pd.concat(fp)
            else:
                print("No valid footprints to save.")
                self.footprints = None

            # Collect per-ring percentile frames the same way.
            for item in result:
                if item.lines_percentile is not None:
                    percentile.append(item.lines_percentile)
                else:
                    print("lines_percentile is None for one of the lines.")
                    continue  # Skip failed line

            if percentile:
                self.lines_percentile = pd.concat(percentile)
            else:
                print("No valid lines_percentile to save.")
                self.lines_percentile = None
        except Exception as e:
            print(f"Error during processing: {e}")

    def save_footprint(self, out_footprint, layer=None):
        """Write the aggregated footprints to out_footprint, if any."""
        if self.footprints is not None and isinstance(self.footprints, gpd.GeoDataFrame):
            self.footprints.to_file(out_footprint, layer=layer)
        else:
            print("No footprints to save (None or not a GeoDataFrame).")

    def save_line_percentile(self, out_percentile):
        """Write the aggregated ring percentiles to out_percentile, if any."""
        if self.lines_percentile is not None and isinstance(self.lines_percentile, gpd.GeoDataFrame):
            self.lines_percentile.to_file(out_percentile)
        else:
            print("No lines_percentile to save (None or not a GeoDataFrame).")

LineInfo

Class to store line information.

Source code in beratools/core/algo_footprint_rel.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
class LineInfo:
    """Class to store line information."""

    def __init__(
        self,
        line_gdf,
        in_chm,
        max_ln_width=32,
        tree_radius=1.5,
        max_line_dist=1.5,
        canopy_avoidance=0.0,
        exponent=1.0,
        canopy_thresh_percentage=50,
    ):
        """Store one line (single-row GeoDataFrame) plus processing parameters.

        Args:
            line_gdf: single-row GeoDataFrame holding the line geometry.
            in_chm: canopy height model raster used for clipping/percentiles.
            max_ln_width: maximum line width used for the side buffers.
            tree_radius: tree search radius passed to the cost raster.
            max_line_dist: distance parameter for the cost raster.
            canopy_avoidance: canopy avoidance weight for the cost raster.
            exponent: cost raster exponent (also reused as shrink-cell count
                in process_single_footprint).
            canopy_thresh_percentage: percentage applied to the cut heights.
        """
        self.line = line_gdf
        self.in_chm = in_chm
        # Simplified copy of the geometry used for ring buffering.
        self.line_simp = self.line.geometry.simplify(tolerance=0.5, preserve_topology=True)

        self.canopy_percentile = 50
        self.DynCanTh = np.nan
        # chk_df_multipart
        # if proc_segments:
        # line_seg = split_into_segments(line_seg)

        self.buffer_rings = []

        # Per-side cut heights/distances, filled by rate_of_change().
        self.CL_CutHt = np.nan
        self.CR_CutHt = np.nan
        self.RDist_Cut = np.nan
        self.LDist_Cut = np.nan

        self.canopy_thresh_percentage = canopy_thresh_percentage
        self.canopy_avoidance = canopy_avoidance
        self.exponent = exponent
        self.max_ln_width = max_ln_width
        self.max_line_dist = max_line_dist
        self.tree_radius = tree_radius

        self.nodata = -9999
        self.dyn_canopy_ndarray = None
        self.negative_cost_clip = None
        self.out_meta = None

        self.buffer_left = None
        self.buffer_right = None
        self.footprint = None

        self.lines_percentile = None

    def compute(self):
        """Run rings, percentiles, cut heights, buffers and footprints for this line."""
        self.prepare_ring_buffer()

        # Drop rings whose percentile computation failed.
        ring_list = []
        for item in self.buffer_rings:
            ring = self.cal_percentileRing(item)
            if ring is not None:
                ring_list.append(ring)
            else:
                print("Skipping invalid ring.")

        self.buffer_rings = ring_list

        # Aggregate percentiles and geometries for lines_percentile
        percentile_records = []
        for ring in self.buffer_rings:
            if ring is not None and hasattr(ring, "geometry") and hasattr(ring, "percentile"):
                percentile_records.append(
                    {"geometry": ring.geometry, "percentile": ring.percentile, "side": ring.side.value}
                )
        if percentile_records:
            self.lines_percentile = gpd.GeoDataFrame(percentile_records)
            self.lines_percentile.set_geometry("geometry", inplace=True)
            if self.line.crs:
                self.lines_percentile = self.lines_percentile.set_crs(self.line.crs, allow_override=True)
        else:
            self.lines_percentile = None

        # Derive cut height/distance per side from the ring percentile profiles.
        self.rate_of_change(self.get_percentile_array(Side.left), Side.left)
        self.rate_of_change(self.get_percentile_array(Side.right), Side.right)

        self.line["CL_CutHt"] = self.CL_CutHt
        self.line["CR_CutHt"] = self.CR_CutHt
        self.line["RDist_Cut"] = self.RDist_Cut
        self.line["LDist_Cut"] = self.LDist_Cut

        # Dynamic canopy threshold is the mean of the two side cut heights.
        self.DynCanTh = (self.CL_CutHt + self.CR_CutHt) / 2
        self.line["DynCanTh"] = self.DynCanTh

        self.prepare_line_buffer()

        fp_left = self.process_single_footprint(Side.left)
        fp_right = self.process_single_footprint(Side.right)

        # Check if footprints are valid
        if fp_left is None or fp_right is None:
            print("One or both footprints are None in LineInfo.")
            self.footprint = None
            return

        try:
            # Buffer cleanup for validity
            fp_left.geometry = fp_left.geometry.buffer(0)
            fp_right.geometry = fp_right.geometry.buffer(0)

            fp_combined = pd.concat([fp_left, fp_right])

            if fp_combined.empty or not isinstance(fp_combined, gpd.GeoDataFrame):
                print("Combined footprint is invalid or empty.")
                self.footprint = None
                return

            # Merge both sides into one footprint and shrink slightly to
            # remove hairline slivers along the shared edge.
            fp_combined = fp_combined.dissolve()
            fp_combined.geometry = fp_combined.geometry.buffer(-0.005)

            self.footprint = fp_combined
        except Exception as e:
            print(f"Error combining footprints: {e}")
            self.footprint = None
            return

        # Transfer group value to footprint if present
        if bt_const.BT_GROUP in self.line.columns:
            self.footprint[bt_const.BT_GROUP] = self.line[bt_const.BT_GROUP].iloc[0]

    def prepare_ring_buffer(self):
        """Create percentile rings on both sides of the simplified line."""
        # Positive step/distance buffers the left side.
        nrings = 1
        ringdist = 15
        ring_list = self.multi_ring_buffer(self.line_simp, nrings, ringdist)
        for i in ring_list:
            # NOTE(review): BufferRing(...) instances are always truthy, so
            # the else branch never runs — confirm the intended emptiness check.
            if BufferRing(i, Side.left):
                self.buffer_rings.append(BufferRing(i, Side.left))
            else:
                print("Empty buffer ring")

        # Negative step/distance buffers the right side.
        nrings = -1
        ringdist = -15
        ring_list = self.multi_ring_buffer(self.line_simp, nrings, ringdist)
        for i in ring_list:
            if BufferRing(i, Side.right):
                self.buffer_rings.append(BufferRing(i, Side.right))
            else:
                print("Empty buffer ring")

    def cal_percentileRing(self, ring):
        """Compute the 50th-percentile CHM height inside one ring.

        Updates ring.percentile and ring.Dyn_Canopy_Threshold; returns the
        ring, or None when its geometry is empty/missing or invalid.
        """
        line_buffer = None
        try:
            line_buffer = ring.geometry
            if line_buffer.is_empty or shapely.is_missing(line_buffer):
                return None
            # Drop Z values so 2D clipping works.
            if line_buffer.has_z:
                line_buffer = sh_ops.transform(lambda x, y, z=None: (x, y), line_buffer)

        except Exception as e:
            print(f"cal_percentileRing: {e}")
            return None

        # TODO: temporary workaround for exception causing not percentile defined
        try:
            clipped_raster, _ = bt_common.clip_raster(self.in_chm, line_buffer, 0)
            clipped_raster = np.squeeze(clipped_raster, axis=0)

            # mask all -9999 (nodata) value cells
            masked_raster = np.ma.masked_where(clipped_raster == bt_const.BT_NODATA, clipped_raster)
            filled_raster = np.ma.filled(masked_raster, np.nan)

            # Calculate the percentile
            percentile = np.nanpercentile(filled_raster, 50)

            # Threshold is 30% of the percentile, floored at 1.
            if percentile > 1:
                ring.Dyn_Canopy_Threshold = percentile * (0.3)
            else:
                ring.Dyn_Canopy_Threshold = 1

            ring.percentile = percentile
        except Exception as e:
            print(e)
            print("Default values are used.")

        return ring

    def get_percentile_array(self, side):
        """Return the list of ring percentiles for the requested side."""
        per_array = []
        for item in self.buffer_rings:
            try:
                if item.side == side:
                    per_array.append(item.percentile)
            except Exception as e:
                print(e)

        return per_array

    def rate_of_change(self, percentile_array, side):
        """Derive cut distance and cut height for one side from its percentiles.

        Scans the ring-to-ring percentile differences for a jump above a
        threshold that is progressively relaxed (150% down to 110%); falls
        back to median-based defaults when no jump is found. Results are
        stored in RDist_Cut/CR_CutHt or LDist_Cut/CL_CutHt.
        """
        # Since the x interval is 1 unit, the array 'diff' is the rate of change (slope)
        diff = np.ediff1d(percentile_array)
        cut_dist = len(percentile_array) / 5

        median_percentile = np.nanmedian(percentile_array)
        if not np.isnan(median_percentile):
            cut_percentile = float(math.floor(median_percentile))
        else:
            cut_percentile = 0.5

        found = False
        changes = 1.50
        Change = np.insert(diff, 0, 0)
        scale_down = 1.0

        # test the rate of change is > than 150% (1.5), if it is
        # no result found then lower to 140% (1.4) until 110% (1.1)
        try:
            while not found and changes >= 1.1:
                for ii in range(0, len(Change) - 1):
                    if percentile_array[ii] >= 0.5:
                        if (Change[ii]) >= changes:
                            cut_dist = (ii + 1) * scale_down
                            cut_percentile = math.floor(percentile_array[ii])

                            if 0.5 >= cut_percentile:
                                if cut_dist > 5:
                                    cut_percentile = 2
                                    cut_dist = cut_dist * scale_down**3
                                    # @<0.5  found and modified
                            elif 0.5 < cut_percentile <= 5.0:
                                if cut_dist > 6:
                                    cut_dist = cut_dist * scale_down**3  # 4.0
                                    # @0.5-5.0  found and modified
                            elif 5.0 < cut_percentile <= 10.0:
                                if cut_dist > 8:  # 5
                                    cut_dist = cut_dist * scale_down**3
                                    # @5-10  found and modified
                            elif 10.0 < cut_percentile <= 15:
                                if cut_dist > 5:
                                    cut_dist = cut_dist * scale_down**3  # 5.5
                                    #  @10-15  found and modified
                            elif 15 < cut_percentile:
                                if cut_dist > 4:
                                    cut_dist = cut_dist * scale_down**2
                                    cut_percentile = 15.5
                                    #  @>15  found and modified
                            found = True
                            # rate of change found
                            break
                changes = changes - 0.1

        except IndexError:
            pass

        # if still no result found, lower to 10% (1.1),
        # if no result found then default is used
        if not found:
            if 0.5 >= median_percentile:
                cut_dist = 4 * scale_down  # 3
                cut_percentile = 0.5
            elif 0.5 < median_percentile <= 5.0:
                cut_dist = 4.5 * scale_down  # 4.0
                cut_percentile = math.floor(median_percentile)
            elif 5.0 < median_percentile <= 10.0:
                cut_dist = 5.5 * scale_down  # 5
                cut_percentile = math.floor(median_percentile)
            elif 10.0 < median_percentile <= 15:
                cut_dist = 6 * scale_down  # 5.5
                cut_percentile = math.floor(median_percentile)
            elif 15 < median_percentile:
                cut_dist = 5 * scale_down  # 5
                cut_percentile = 15.5

        if side == Side.right:
            self.RDist_Cut = cut_dist
            self.CR_CutHt = float(cut_percentile)
        elif side == Side.left:
            self.LDist_Cut = cut_dist
            self.CL_CutHt = float(cut_percentile)

    def multi_ring_buffer(self, df, nrings, ringdist):
        """
        Buffers an input DataFrame's geometry nrings (number of rings) times.

        Compute with a distance between rings of ringdist and returns
        a list of non-overlapping buffers
        """
        rings = []  # A list to hold the individual buffers
        line = df.geometry.iloc[0]
        # For each ring (1, 2, 3, ..., nrings)
        for ring in np.arange(0, ringdist, nrings):
            big_ring = line.buffer(
                nrings + ring, single_sided=True, cap_style="flat"
            )  # Create one big buffer
            small_ring = line.buffer(ring, single_sided=True, cap_style="flat")  # Create one smaller one
            the_ring = big_ring.difference(small_ring)  # Difference the big with the small to create a ring
            # NOTE(review): this condition is always True because of the
            # `not None` term, so empty rings are not filtered here — confirm
            # the intended emptiness/area test.
            if (
                ~shapely.is_empty(the_ring)
                or ~shapely.is_missing(the_ring)
                or not None
                or ~the_ring.area == 0
            ):
                if isinstance(the_ring, sh_geom.MultiPolygon) or isinstance(the_ring, shapely.Polygon):
                    rings.append(the_ring)  # Append the ring to the rings list
                else:
                    if isinstance(the_ring, shapely.GeometryCollection):
                        for i in range(0, len(the_ring.geoms)):
                            if not isinstance(the_ring.geoms[i], shapely.LineString):
                                rings.append(the_ring.geoms[i])

        return rings  # return the list

    def prepare_line_buffer(self):
        """Build the left/right single-sided analysis buffers for the line."""
        line = self.line.geometry.iloc[0]
        # Left side: wide buffer plus a 1-unit overlap past the centerline.
        buffer_left_1 = line.buffer(
            distance=self.max_ln_width + 1,
            cap_style=3,
            single_sided=True,
        )

        buffer_left_2 = line.buffer(
            distance=-1,
            cap_style=3,
            single_sided=True,
        )

        self.buffer_left = sh_ops.unary_union([buffer_left_1, buffer_left_2])

        # Right side: mirror of the left-side construction.
        buffer_right_1 = line.buffer(
            distance=-self.max_ln_width - 1,
            cap_style=3,
            single_sided=True,
        )
        buffer_right_2 = line.buffer(distance=1, cap_style=3, single_sided=True)

        self.buffer_right = sh_ops.unary_union([buffer_right_1, buffer_right_2])

    def dyn_canopy_cost_raster(self, side):
        """Clip the CHM to one side's buffer and build its cost/canopy rasters.

        Returns:
            tuple: (canopy array, cost array, raster meta, cut distance),
            or None on failure.
        """
        in_chm_raster = self.in_chm
        # tree_radius = self.tree_radius
        # max_line_dist = self.max_line_dist
        # canopy_avoid = self.canopy_avoidance
        # exponent = self.exponent
        line_df = self.line
        out_meta = self.out_meta

        canopy_thresh_percentage = self.canopy_thresh_percentage / 100

        # Pick side-specific threshold, cut distance and buffer.
        Cut_Dist = None
        line_buffer = None
        if side == Side.left:
            canopy_ht_threshold = line_df.CL_CutHt * canopy_thresh_percentage
            Cut_Dist = self.LDist_Cut
            line_buffer = self.buffer_left
        elif side == Side.right:
            canopy_ht_threshold = line_df.CR_CutHt * canopy_thresh_percentage
            Cut_Dist = self.RDist_Cut
            line_buffer = self.buffer_right
        else:
            canopy_ht_threshold = 0.5
            Cut_Dist = 1.0
            line_buffer = None

        # Guard against non-positive thresholds.
        canopy_ht_threshold = float(canopy_ht_threshold)
        if canopy_ht_threshold <= 0:
            canopy_ht_threshold = 0.5

        # get the round up integer number for tree search radius
        # tree_radius = float(tree_radius)
        # max_line_dist = float(max_line_dist)
        # canopy_avoid = float(canopy_avoid)
        # cost_raster_exponent = float(exponent)

        try:
            clipped_rasterC, out_meta = bt_common.clip_raster(in_chm_raster, line_buffer, 0)
            negative_cost_clip, dyn_canopy_ndarray = algo_cost.cost_raster(
                clipped_rasterC,
                out_meta,
                self.tree_radius,
                canopy_ht_threshold,
                self.max_line_dist,
                self.canopy_avoidance,
                self.exponent,
            )

            return dyn_canopy_ndarray, negative_cost_clip, out_meta, Cut_Dist

        except Exception as e:
            print(f"dyn_canopy_cost_raster: {e}")
            return None

    def process_single_footprint(self, side):
        """Generate the footprint GeoDataFrame for one side of the line.

        Returns None when no valid polygon could be produced.
        """
        # this will change segment content, and parameters will be changed
        # NOTE(review): dyn_canopy_cost_raster returns None on failure, which
        # would make this unpacking raise TypeError — confirm intended.
        in_canopy_r, in_cost_r, in_meta, Cut_Dist = self.dyn_canopy_cost_raster(side)

        if np.isnan(in_canopy_r).all():
            print("Canopy raster empty")

        if np.isnan(in_cost_r).all():
            print("Cost raster empty")

        exp_shk_cell = self.exponent  # TODO: duplicate vars
        no_data = self.nodata

        shapefile_proj = self.line.crs
        in_transform = in_meta["transform"]

        segment_list = []

        # Flatten the line's coordinates (handles multi-part geometries).
        feat = self.line.geometry.iloc[0]
        if hasattr(feat, "geoms"):
            for geom in feat.geoms:
                for coord in geom.coords:
                    segment_list.append(coord)
        else:
            for coord in feat.coords:
                segment_list.append(coord)

        cell_size_x = in_transform[0]
        cell_size_y = -in_transform[4]

        # Work out the corridor from both end of the centerline
        try:
            if len(in_cost_r.shape) > 2:
                in_cost_r = np.squeeze(in_cost_r, axis=0)

            algo_cost.remove_nan_from_array_refactor(in_cost_r)
            in_cost_r[in_cost_r == no_data] = np.inf

            # generate 1m interval points along line
            distance_delta = 1
            distances = np.arange(0, feat.length, distance_delta)
            multipoint_along_line = [feat.interpolate(distance) for distance in distances]
            multipoint_along_line.append(sh_geom.Point(segment_list[-1]))
            # Rasterize points along line
            rasterized_points_Alongln = ras_feat.rasterize(
                multipoint_along_line,
                out_shape=in_cost_r.shape,
                transform=in_transform,
                fill=0,
                all_touched=True,
                default_value=1,
            )
            points_Alongln = np.transpose(np.nonzero(rasterized_points_Alongln))

            # Find minimum cost paths through an N-d costs array.
            mcp_flexible1 = MCP_Flexible(in_cost_r, sampling=(cell_size_x, cell_size_y), fully_connected=True)
            flex_cost_alongLn, flex_back_alongLn = mcp_flexible1.find_costs(starts=points_Alongln)

            # Generate corridor
            corridor = flex_cost_alongLn
            corridor = np.ma.masked_invalid(corridor)

            # Calculate minimum value of corridor raster
            if np.ma.min(corridor) is not None:
                corr_min = float(np.ma.min(corridor))
            else:
                corr_min = 0.5

            # normalize corridor raster by deducting corr_min
            corridor_norm = corridor - corr_min

            # Set minimum as zero and save minimum file
            corridor_th_value = Cut_Dist / cell_size_x
            if corridor_th_value < 0:  # if no threshold found, use default value
                corridor_th_value = bt_const.FP_CORRIDOR_THRESHOLD / cell_size_x

            corridor_thresh = np.ma.where(corridor_norm >= corridor_th_value, 1.0, 0.0)
            clean_raster = algo_common.morph_raster(corridor_thresh, in_canopy_r, exp_shk_cell, cell_size_x)

            # create mask for non-polygon area
            mask = np.where(clean_raster == 1, True, False)
            if clean_raster.dtype == np.int64:
                clean_raster = clean_raster.astype(np.int32)

            # Process: ndarray to shapely Polygon
            out_polygon = ras_feat.shapes(clean_raster, mask=mask, transform=in_transform)

            # create a shapely MultiPolygon
            multi_polygon = []
            if out_polygon is not None:
                try:
                    for poly, value in out_polygon:
                        multi_polygon.append(sh_geom.shape(poly))
                except TypeError:
                    pass

            if not multi_polygon:
                print("No polygons generated from raster. Returning None.")
                return None

            poly = sh_geom.MultiPolygon(multi_polygon) if multi_polygon else None

            # create GeoDataFrame directly from dictionary
            out_gdata = gpd.GeoDataFrame({"CorriThresh": [corridor_th_value], "geometry": [poly]})
            out_gdata.set_geometry("geometry", inplace=True)
            if shapefile_proj:
                out_gdata = out_gdata.set_crs(shapefile_proj, allow_override=True)

            if out_gdata is None or out_gdata.empty or out_gdata.geometry.isnull().all():
                print("Empty GeoDataFrame from process_single_footprint.")
                return None

            return out_gdata

        except Exception as e:
            print("Exception: {}".format(e))

multi_ring_buffer(df, nrings, ringdist)

Buffers an input DataFrame's geometry nrings (number of rings) times.

Computes with a distance between rings of ringdist and returns a list of non-overlapping buffers.

Source code in beratools/core/algo_footprint_rel.py
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def multi_ring_buffer(self, df, nrings, ringdist):
    """
    Buffer the first geometry of *df* into a sequence of concentric rings.

    Each ring is the difference between two single-sided flat-cap buffers,
    so the returned rings do not overlap.

    Args:
        df: GeoDataFrame; only the first row's geometry is buffered.
        nrings: buffer step (ring width) used for each ring.
            NOTE(review): the name suggests "number of rings" but the code
            uses it as the arange step / ring width — confirm with callers.
        ringdist: total buffering distance covered by the rings.

    Returns:
        list: non-overlapping ring geometries (Polygon/MultiPolygon);
        LineString members of GeometryCollections are discarded.
    """
    rings = []  # A list to hold the individual buffers
    line = df.geometry.iloc[0]
    for ring in np.arange(0, ringdist, nrings):
        # One big buffer and one smaller buffer; their difference is the ring.
        big_ring = line.buffer(nrings + ring, single_sided=True, cap_style="flat")
        small_ring = line.buffer(ring, single_sided=True, cap_style="flat")
        the_ring = big_ring.difference(small_ring)

        # Skip degenerate rings.  (The original condition applied bitwise "~"
        # to booleans and included "or not None", making it always true, so
        # empty rings were never filtered out.)
        if the_ring is None or shapely.is_empty(the_ring) or the_ring.area == 0:
            continue

        if isinstance(the_ring, (sh_geom.MultiPolygon, shapely.Polygon)):
            rings.append(the_ring)  # Append the ring to the rings list
        elif isinstance(the_ring, shapely.GeometryCollection):
            # Keep only the areal parts of a mixed collection.
            for member in the_ring.geoms:
                if not isinstance(member, shapely.LineString):
                    rings.append(member)

    return rings  # return the list

Side

Bases: Enum

Constants for left and right side.

Source code in beratools/core/algo_footprint_rel.py
38
39
40
41
42
class Side(Enum):
    """Enumerates the two sides of a line: left and right."""

    # String values match the side names used elsewhere in the module.
    left = "left"
    right = "right"

line_footprint_rel(in_line, in_chm, out_footprint, processes, verbose=True, in_layer=None, out_layer=None, max_ln_width=32, tree_radius=1.5, max_line_dist=1.5, canopy_avoidance=0.0, exponent=1.0, canopy_thresh_percentage=50, parallel_mode=bt_const.ParallelMode.MULTIPROCESSING)

Safe version of relative canopy footprint tool.

Source code in beratools/core/algo_footprint_rel.py
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
def line_footprint_rel(
    in_line,
    in_chm,
    out_footprint,
    processes,
    verbose=True,
    in_layer=None,
    out_layer=None,
    max_ln_width=32,
    tree_radius=1.5,
    max_line_dist=1.5,
    canopy_avoidance=0.0,
    exponent=1.0,
    canopy_thresh_percentage=50,
    parallel_mode=bt_const.ParallelMode.MULTIPROCESSING,
):
    """Safe version of relative canopy footprint tool."""
    # Construction and computation are guarded separately so a failure in
    # either stage is reported and the function returns without raising.
    try:
        fp_tool = FootprintCanopy(in_line, in_chm, in_layer=in_layer)
    except Exception as e:
        print(f"Failed to initialize FootprintCanopy: {e}")
        return

    try:
        fp_tool.compute(processes, parallel_mode)
    except Exception as e:
        print(f"Error in compute(): {e}")
        import traceback

        traceback.print_exc()
        return

    def _usable(gdf):
        # A result is saved only when it exists and is a non-empty frame.
        return gdf is not None and hasattr(gdf, "empty") and not gdf.empty

    # Save only if footprints were actually generated
    footprints = getattr(fp_tool, "footprints", None)
    if _usable(footprints):
        try:
            fp_tool.save_footprint(out_footprint, out_layer)
            if verbose:
                print(f"Footprint saved to {out_footprint} layer={out_layer}")
        except Exception as e:
            print(f"Failed to save footprint: {e}")
    else:
        print("No valid footprints to save.")

    # Optionally save percentile lines (if needed)
    percentile = getattr(fp_tool, "lines_percentile", None)
    if _usable(percentile):
        out_percentile = out_footprint.replace("footprint", "line_percentile")
        try:
            fp_tool.save_line_percentile(out_percentile)
            if verbose:
                print(f"Line percentile saved to {out_percentile}")
        except Exception as e:
            print(f"Failed to save line percentile: {e}")

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng, Maverick Fong

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

This file hosts code for line grouping, merging, and cleanup.

LineGrouping

Class to group lines and merge them.

Source code in beratools/core/algo_line_grouping.py
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
class LineGrouping:
    """Class to group lines and merge them.

    Builds a connectivity graph over the input line GeoDataFrame (one graph
    node per line), merges coincident line endpoints into VertexNode objects,
    assigns group ids via connected components, then trims and cleans the
    line and footprint-polygon geometries around shared vertices.
    """

    def __init__(self, in_line_gdf, merge_group=True, use_angle_grouping=True) -> None:
        """Initialize grouping state from a line GeoDataFrame.

        Args:
            in_line_gdf: input GeoDataFrame of line geometries (required).
            merge_group: passed through to the line-merge step.
            use_angle_grouping: if True, endpoint connectivity also
                considers line angles (see VertexNode.check_connectivity).

        Raises:
            ValueError: if in_line_gdf is None or empty.
        """
        if in_line_gdf is None:
            raise ValueError("Line GeoDataFrame cannot be None")
        self.use_angle_grouping = use_angle_grouping

        if in_line_gdf.empty:
            raise ValueError("Line GeoDataFrame cannot be empty")

        self.lines = algo_common.clean_line_geometries(in_line_gdf)
        self.lines.reset_index(inplace=True, drop=True)
        self.merge_group = merge_group

        # Simplified copies of the lines (tolerance 1) used for endpoint work.
        self.sim_geom = self.lines.simplify(1)

        # One graph node per line; edges are added between connected lines.
        self.G = nk.Graph(len(self.lines))
        self.merged_vertex_list = []
        self.has_group_attr = False
        self.need_regrouping = False
        self.groups = [None] * len(self.lines)  # group id per line
        self.merged_lines_trimmed = None  # merged trimmed lines

        self.vertex_list = []  # two VertexNodes per line (both endpoints)
        self.vertex_of_concern = []  # vertices whose class requires trimming
        self.v_index = None  # sindex of all vertices for vertex_list

        self.polys = None  # footprint polygons, set by run_cleanup

        # invalid geoms in final geom list
        self.valid_lines = None
        self.valid_polys = None
        self.invalid_lines = None
        self.invalid_polys = None

    def create_vertex_list(self):
        """Build endpoint VertexNodes, merge coincident ones, add graph edges."""
        # check if data has group column
        if bt_const.BT_GROUP in self.lines.keys():
            self.groups = self.lines[bt_const.BT_GROUP]
            self.has_group_attr = True
            if self.groups.hasnans:  # Todo: check for other invalid values
                self.need_regrouping = True

        # One VertexNode per end (first point 0, last point -1) of every line.
        for idx, s_geom, geom, group in zip(*zip(*self.sim_geom.items()), self.lines.geometry, self.groups):
            self.vertex_list.append(VertexNode(idx, geom, s_geom, 0, group))
            self.vertex_list.append(VertexNode(idx, geom, s_geom, -1, group))

        v_points = []
        for i in self.vertex_list:
            if i.vertex is None:
                print("Vertex is None, skipping.")
                continue

            v_points.append(i.vertex.buffer(SMALL_BUFFER))  # small polygon

        # Spatial index of all vertices
        # NOTE(review): if any vertex was skipped above, v_points indices no
        # longer line up with vertex_list indices used below — confirm.
        self.v_index = shapely.STRtree(v_points)

        # Merge vertices whose buffers overlap into one VertexNode each.
        vertex_visited = [False] * len(self.vertex_list)
        for i, pt in enumerate(v_points):
            if vertex_visited[i]:
                continue

            s_list = self.v_index.query(pt)
            vertex = self.vertex_list[i]
            if len(s_list) > 1:
                for j in s_list:
                    if j != i:
                        # some short line will be very close to each other
                        # NOTE(review): bare SMALL_BUFFER is used above but
                        # bt_const.SMALL_BUFFER here — confirm they are the
                        # same constant.
                        if vertex.vertex.distance(self.vertex_list[j].vertex) > bt_const.SMALL_BUFFER:
                            continue

                        vertex.merge(self.vertex_list[j])
                        vertex_visited[j] = True

            self.merged_vertex_list.append(vertex)
            vertex_visited[i] = True

        # Classify each merged vertex and derive its connected line pairs.
        for i in self.merged_vertex_list:
            i.check_connectivity(self.use_angle_grouping)

        # One graph edge per connected line pair.
        for i in self.merged_vertex_list:
            if i.line_connected:
                for edge in i.line_connected:
                    self.G.addEdge(edge[0], edge[1])

    def group_lines(self):
        """Assign a group id per line from graph connected components."""
        cc = nk.components.ConnectedComponents(self.G)
        cc.run()
        # print("number of components ", cc.numberOfComponents())

        group = 0
        for i in range(cc.numberOfComponents()):
            component = cc.getComponents()[i]
            for id in component:
                self.groups[id] = group

            group += 1

    def update_line_in_vertex_node(self, line_id, line):
        """Update line in VertexNode after trimming."""
        # Every vertex whose buffer intersects the new line may reference it;
        # VertexNode.update_line only replaces entries matching line_id.
        idx = self.v_index.query(line)
        for i in idx:
            v = self.vertex_list[i]
            v.update_line(line_id, line)

    def run_line_merge(self):
        """Merge lines (optionally per group) via algo_merge_lines."""
        return algo_merge_lines.run_line_merge(self.lines, self.merge_group)

    def find_vertex_for_poly_trimming(self):
        """Collect merged vertices whose class requires polygon trimming."""
        self.vertex_of_concern = [i for i in self.merged_vertex_list if i.vertex_class in CONCERN_CLASSES]

    def line_and_poly_cleanup(self):
        """Trim polygons and lines around each vertex of concern."""
        sindex_poly = self.polys.sindex

        for vertex in self.vertex_of_concern:
            s_idx = sindex_poly.query(vertex.vertex, predicate="within")
            if len(s_idx) == 0:
                continue

            #  Trim intersections of primary lines
            polys = self.polys.loc[s_idx].geometry
            if not self.merge_group:
                if (
                    vertex.vertex_class == VertexClass.FIVE_WAY_TWO_PRIMARY_LINE
                    or vertex.vertex_class == VertexClass.FIVE_WAY_ONE_PRIMARY_LINE
                    or vertex.vertex_class == VertexClass.FOUR_WAY_ONE_PRIMARY_LINE
                    or vertex.vertex_class == VertexClass.FOUR_WAY_TWO_PRIMARY_LINE
                    or vertex.vertex_class == VertexClass.THREE_WAY_ONE_PRIMARY_LINE
                ):
                    out_polys = vertex.trim_primary_end(polys)
                    if len(out_polys) == 0:
                        continue

                    # update polygon DataFrame
                    for idx, out_poly in out_polys:
                        if out_poly:
                            self.polys.at[idx, "geometry"] = out_poly

            # retrieve polygons again. Some polygons may be updated
            polys = self.polys.loc[s_idx]
            if (
                vertex.vertex_class == VertexClass.SINGLE_WAY
                or vertex.vertex_class == VertexClass.TWO_WAY_ZERO_PRIMARY_LINE
                or vertex.vertex_class == VertexClass.THREE_WAY_ZERO_PRIMARY_LINE
                or vertex.vertex_class == VertexClass.FOUR_WAY_ZERO_PRIMARY_LINE
                or vertex.vertex_class == VertexClass.FIVE_WAY_ZERO_PRIMARY_LINE
            ):
                # NOTE(review): this branch is a no-op — looks like a leftover
                # debugging hook; confirm it can be removed.
                if vertex.vertex_class == VertexClass.THREE_WAY_ZERO_PRIMARY_LINE:
                    pass

                out_polys = vertex.trim_end_all(polys)
                if len(out_polys) == 0:
                    continue

                # update polygon DataFrame
                for idx, out_poly in out_polys:
                    self.polys.at[idx, "geometry"] = out_poly

            polys = self.polys.loc[s_idx]
            if vertex.vertex_class != VertexClass.SINGLE_WAY:
                poly_trim_list = vertex.trim_intersection(polys, self.merge_group)
                for p_trim in poly_trim_list:
                    # update main line and polygon DataFrame
                    self.polys.at[p_trim.poly_index, "geometry"] = p_trim.poly_cleanup
                    self.lines.at[p_trim.line_index, "geometry"] = p_trim.line_cleanup

                    # update VertexNode's line
                    self.update_line_in_vertex_node(p_trim.line_index, p_trim.line_cleanup)

    def get_merged_lines_original(self):
        """Return the untrimmed lines dissolved by group id."""
        return self.lines.dissolve(by=bt_const.BT_GROUP)

    def run_grouping(self):
        """Build the vertex list, group lines if needed, record group ids."""
        self.create_vertex_list()
        if not self.has_group_attr:
            self.group_lines()

        self.find_vertex_for_poly_trimming()
        self.lines[bt_const.BT_GROUP] = self.groups  # assign group attribute

    def run_regrouping(self):
        """
        Run this when new lines are added to grouped file.

        Some new lines have empty group attributes. Not implemented yet.
        """
        pass

    def run_cleanup(self, in_polys):
        """Trim lines/polygons, merge trimmed lines, and validate geometry.

        Args:
            in_polys: footprint polygon GeoDataFrame (copied, not mutated).
        """
        self.polys = in_polys.copy()
        self.line_and_poly_cleanup()
        self.run_line_merge_trimmed()
        self.check_geom_validity()

    def run_line_merge_trimmed(self):
        """Merge the trimmed lines and store the result."""
        self.merged_lines_trimmed = self.run_line_merge()

    def check_geom_validity(self):
        """
        Check MultiLineString and MultiPolygon in line and polygon dataframe.

        Save to separate layers for user to double check
        """
        #  remove null geometry
        # TODO make sure lines and polygons match in pairs
        # they should have same amount and spatial coverage
        self.valid_polys = self.polys[~self.polys.geometry.isna() & ~self.polys.geometry.is_empty]

        # save sh_geom.MultiLineString and sh_geom.MultiPolygon
        # NOTE(review): multipolygons also remain in valid_polys when
        # non-empty — confirm the overlap between layers is intended.
        self.invalid_polys = self.polys[(self.polys.geometry.geom_type == "MultiPolygon")]

        # check lines
        self.valid_lines = self.merged_lines_trimmed[
            ~self.merged_lines_trimmed.geometry.isna() & ~self.merged_lines_trimmed.geometry.is_empty
        ]
        self.valid_lines.reset_index(inplace=True, drop=True)

        self.invalid_lines = self.merged_lines_trimmed[
            (self.merged_lines_trimmed.geometry.geom_type == "MultiLineString")
        ]
        self.invalid_lines.reset_index(inplace=True, drop=True)

    def save_file(self, out_file):
        """Write valid/invalid lines and polygons as layers of out_file.

        Empty frames are skipped so no empty layers are created.
        """
        if not self.valid_lines.empty:
            self.valid_lines["length"] = self.valid_lines.length
            self.valid_lines.to_file(out_file, layer="merged_lines")

        if not self.valid_polys.empty:
            # "length" belongs to lines; drop it before writing polygons.
            if "length" in self.valid_polys.columns:
                self.valid_polys.drop(columns=["length"], inplace=True)

            self.valid_polys["area"] = self.valid_polys.area
            self.valid_polys.to_file(out_file, layer="clean_footprint")

        if not self.invalid_lines.empty:
            self.invalid_lines.to_file(out_file, layer="invalid_lines")

        if not self.invalid_polys.empty:
            self.invalid_polys.to_file(out_file, layer="invalid_polygons")

check_geom_validity()

Check MultiLineString and MultiPolygon in line and polygon dataframe.

Save to separate layers for the user to double-check.

Source code in beratools/core/algo_line_grouping.py
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
def check_geom_validity(self):
    """
    Check MultiLineString and MultiPolygon in line and polygon dataframe.

    Save to separate layers for user to double check
    """
    #  remove null geometry
    # TODO make sure lines and polygons match in pairs
    # they should have same amount and spatial coverage
    self.valid_polys = self.polys[~self.polys.geometry.isna() & ~self.polys.geometry.is_empty]

    # save sh_geom.MultiLineString and sh_geom.MultiPolygon
    # NOTE(review): multipolygons also remain in valid_polys when non-empty —
    # confirm the overlap between the valid and invalid layers is intended.
    self.invalid_polys = self.polys[(self.polys.geometry.geom_type == "MultiPolygon")]

    # check lines
    self.valid_lines = self.merged_lines_trimmed[
        ~self.merged_lines_trimmed.geometry.isna() & ~self.merged_lines_trimmed.geometry.is_empty
    ]
    self.valid_lines.reset_index(inplace=True, drop=True)

    self.invalid_lines = self.merged_lines_trimmed[
        (self.merged_lines_trimmed.geometry.geom_type == "MultiLineString")
    ]
    self.invalid_lines.reset_index(inplace=True, drop=True)

run_regrouping()

Run this when new lines are added to grouped file.

Some new lines have empty group attributes

Source code in beratools/core/algo_line_grouping.py
797
798
799
800
801
802
803
def run_regrouping(self):
    """
    Re-derive group ids after new lines are appended to a grouped file.

    Newly appended lines may carry empty group attributes; this hook is a
    placeholder for assigning them. Not implemented yet.
    """
    pass

update_line_in_vertex_node(line_id, line)

Update line in VertexNode after trimming.

Source code in beratools/core/algo_line_grouping.py
715
716
717
718
719
720
def update_line_in_vertex_node(self, line_id, line):
    """Update line in VertexNode after trimming."""
    # Every vertex whose buffer intersects the new line may reference it;
    # VertexNode.update_line only replaces entries that match line_id.
    idx = self.v_index.query(line)
    for i in idx:
        v = self.vertex_list[i]
        v.update_line(line_id, line)

PolygonTrimming dataclass

Store polygon and line to trim. Primary polygon is used to trim both.

Source code in beratools/core/algo_line_grouping.py
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
@dataclass
class PolygonTrimming:
    """Store polygon and line to trim. Primary polygon is used to trim both.

    Attributes:
        poly_primary: union of neighboring primary polygons used as the
            cutting shape (computed inside process()).
        poly_index: row index of the polygon in the polygon GeoDataFrame.
        poly_cleanup: polygon being trimmed; updated in place by process().
        line_index: row index of the line in the line GeoDataFrame.
        line_cleanup: line being trimmed; updated in place by process().
    """

    poly_primary: Optional[sh_geom.MultiPolygon] = None
    poly_index: int = field(default=-1)
    poly_cleanup: Optional[sh_geom.Polygon] = None
    line_index: int = field(default=-1)
    line_cleanup: Optional[sh_geom.LineString] = None

    def process(self, primary_poly_list=None, vertex=None):
        """Trim poly_cleanup and line_cleanup against the primary polygons.

        Args:
            primary_poly_list: geometries unioned into the cutting shape.
            vertex: vertex point; trimming is confined to a buffer around it.
        """
        # prepare primary polygon
        poly_primary = shapely.union_all(primary_poly_list)
        trim_distance = TRIMMING_DISTANCE

        # Short lines get a smaller trimming reach around the vertex.
        if self.line_cleanup.length < 100.0:
            trim_distance = 50.0

        poly_primary = poly_primary.intersection(vertex.buffer(trim_distance))

        self.poly_primary = poly_primary

        # TODO: check why there is such cases
        if self.poly_cleanup is None:
            print("No polygon to trim.")
            return

        # The line midpoint decides which piece of a split polygon to keep.
        midpoint = self.line_cleanup.interpolate(0.5, normalized=True)
        diff = self.poly_cleanup.difference(self.poly_primary)
        if diff.geom_type == "Polygon":
            self.poly_cleanup = diff
        elif diff.geom_type == "MultiPolygon":
            # area = self.poly_cleanup.area
            reserved = []
            for i in diff.geoms:
                # if i.area > TRIM_THRESHOLD * area:  # small part
                #     reserved.append(i)
                if i.contains(midpoint):
                    reserved.append(i)

            if len(reserved) == 0:
                pass
            elif len(reserved) == 1:
                self.poly_cleanup = sh_geom.Polygon(*reserved)
            else:
                # TODO output all MultiPolygons which should be dealt with
                # self.poly_cleanup = sh_geom.MultiPolygon(reserved)
                print("trim: MultiPolygon detected, please check")

        # Clip the line to the cleaned polygon.
        diff = self.line_cleanup.intersection(self.poly_cleanup)
        if diff.geom_type == "GeometryCollection":
            geoms = []
            for item in diff.geoms:
                if item.geom_type == "LineString":
                    geoms.append(item)
                elif item.geom_type == "MultiLineString":
                    print("trim: sh_geom.MultiLineString detected, please check")
            if len(geoms) == 0:
                return
            elif len(geoms) == 1:
                diff = geoms[0]
            else:
                diff = sh_geom.MultiLineString(geoms)

        if diff.geom_type == "LineString":
            self.line_cleanup = diff
        elif diff.geom_type == "MultiLineString":
            length = self.line_cleanup.length
            reserved = []
            for i in diff.geoms:
                if i.length > TRIM_THRESHOLD * length:  # keep substantial parts
                    reserved.append(i)

            if len(reserved) == 0:
                pass
            elif len(reserved) == 1:
                self.line_cleanup = sh_geom.LineString(*reserved)
            else:
                # TODO output all MultiLineStrings which should be dealt with
                # Bug fix: the original assigned this MultiLineString to
                # poly_cleanup; the line branch must update line_cleanup.
                self.line_cleanup = sh_geom.MultiLineString(reserved)

SingleLine dataclass

Class to store line and its simplified line.

Source code in beratools/core/algo_line_grouping.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@dataclass
class SingleLine:
    """Class to store line and its simplified line."""

    line_id: int = field(default=0)  # index of the line in the source frame
    # Optional added: both geometry fields default to None.
    line: Optional[Union[sh_geom.LineString, sh_geom.MultiLineString]] = field(default=None)
    sim_line: Optional[Union[sh_geom.LineString, sh_geom.MultiLineString]] = field(default=None)
    vertex_index: int = field(default=0)  # 0 = first point, -1 = last point
    group: int = field(default=0)  # group id the line belongs to

    def get_angle_for_line(self):
        """Return the angle of the simplified line at this end (via get_angle)."""
        return get_angle(self.sim_line, self.vertex_index)

    def end_transect(self):
        """Return a transect across the line's end segment.

        Offsets the end segment by TRANSECT_LENGTH on both sides and joins
        the two offset start points into one LineString.
        NOTE(review): if vertex_index is neither 0 nor -1, end_seg stays None
        and offset_curve raises — confirm callers only use 0/-1.
        """
        coords = self.sim_line.coords
        end_seg = None
        if self.vertex_index == 0:
            end_seg = sh_geom.LineString([coords[0], coords[1]])
        elif self.vertex_index == -1:
            end_seg = sh_geom.LineString([coords[-1], coords[-2]])

        l_left = end_seg.offset_curve(TRANSECT_LENGTH)
        l_right = end_seg.offset_curve(-TRANSECT_LENGTH)

        return sh_geom.LineString([l_left.coords[0], l_right.coords[0]])

    def midpoint(self):
        """Return the 2D midpoint of the original line."""
        return shapely.force_2d(self.line.interpolate(0.5, normalized=True))

    def update_line(self, line):
        """Replace the stored line geometry (e.g. after trimming)."""
        self.line = line

VertexClass

Bases: IntEnum

Enum class for vertex class.

Source code in beratools/core/algo_line_grouping.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@enum.unique
class VertexClass(enum.IntEnum):
    """Enum class for vertex class.

    Names encode how many lines meet at the vertex (N_WAY) and how many of
    them are primary lines (M_PRIMARY_LINE).
    """

    TWO_WAY_ZERO_PRIMARY_LINE = 1
    THREE_WAY_ZERO_PRIMARY_LINE = 2
    THREE_WAY_ONE_PRIMARY_LINE = 3
    FOUR_WAY_ZERO_PRIMARY_LINE = 4
    FOUR_WAY_ONE_PRIMARY_LINE = 5
    FOUR_WAY_TWO_PRIMARY_LINE = 6
    FIVE_WAY_ZERO_PRIMARY_LINE = 7
    FIVE_WAY_ONE_PRIMARY_LINE = 8
    FIVE_WAY_TWO_PRIMARY_LINE = 9
    SINGLE_WAY = 10

VertexNode

Class to store vertex and lines connected to it.

Source code in beratools/core/algo_line_grouping.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
class VertexNode:
    """Store a vertex and the lines connected to it."""

    def __init__(self, line_id, line, sim_line, vertex_index, group=None) -> None:
        """
        Initialize the node and optionally register a first line.

        Args:
        line_id: unique id of the line
        line: line geometry
        sim_line: simplified line geometry
        vertex_index: which end of the line touches this vertex (0 or -1)
        group: optional group attribute of the line

        """
        self.vertex = None  # shapely Point of this node; set via add_line
        self.line_list = []  # SingleLine objects meeting at this vertex
        self.line_connected = []  # pairs of line ids considered connected
        self.line_not_connected = []  # line ids not in any connected pair
        self.vertex_class = None  # VertexClass set by assign_vertex_class

        if line:
            self.add_line(SingleLine(line_id, line, sim_line, vertex_index, group))

    def set_vertex(self, line, vertex_index):
        """Set vertex coordinates."""
        self.vertex = shapely.force_2d(shapely.get_point(line, vertex_index))

    def add_line(self, line_class):
        """Add line when creating or merging other VertexNode."""
        self.line_list.append(line_class)
        self.set_vertex(line_class.line, line_class.vertex_index)

    def get_line(self, line_id):
        """Return the line geometry for line_id, or None when absent."""
        for line in self.line_list:
            if line.line_id == line_id:
                return line.line

    def get_line_obj(self, line_id):
        """Return the line object for line_id, or None when absent."""
        for line in self.line_list:
            if line.line_id == line_id:
                return line

    def get_line_geom(self, line_id):
        """Return the geometry of the line object with line_id."""
        return self.get_line_obj(line_id).line

    def get_all_line_ids(self):
        """Return the set of all line ids stored at this vertex."""
        all_line_ids = {i.line_id for i in self.line_list}
        return all_line_ids

    def update_line(self, line_id, line):
        """Replace the geometry of the line identified by line_id."""
        for i in self.line_list:
            if i.line_id == line_id:
                i.update_line(line)

    def merge(self, vertex):
        """Merge other VertexNode if they have same vertex coords."""
        self.add_line(vertex.line_list[0])

    def get_trim_transect(self, poly, line_indices):
        """
        Return an end transect for the line (among line_indices) whose
        midpoint lies inside poly, or None when nothing qualifies.
        """
        if not poly:
            return None

        internal_line = None
        for line_idx in line_indices:
            line = self.get_line_obj(line_idx)
            if poly.contains(line.midpoint()):
                internal_line = line

        if not internal_line:
            # print("No line is retrieved")
            return
        return internal_line.end_transect()

    def _trim_polygon(self, poly, trim_transect):
        """Split poly by trim_transect and keep the larger half."""
        if not poly or not trim_transect:
            return

        split_poly = shapely.ops.split(poly, trim_transect)

        # a valid cut yields exactly two pieces
        if len(split_poly.geoms) != 2:
            return

        # check geom_type: reject the split if any piece is not a polygon
        none_poly = False
        for geom in split_poly.geoms:
            if geom.geom_type != "Polygon":
                none_poly = True

        if none_poly:
            return

        # only two polygons in split_poly; keep the larger one
        if split_poly.geoms[0].area > split_poly.geoms[1].area:
            poly = split_poly.geoms[0]
        else:
            poly = split_poly.geoms[1]

        return poly

    def trim_end_all(self, polys):
        """
        Trim all unconnected lines in the vertex.

        Args:
        polys: list of polygons returned by sindex.query

        """
        polys = polys.geometry
        new_polys = []
        for idx, poly in polys.items():
            out_poly = self.trim_end(poly)
            if out_poly:
                new_polys.append((idx, out_poly))

        return new_polys

    def trim_end(self, poly):
        """Trim poly with a transect from an unconnected line it contains."""
        transect = self.get_trim_transect(poly, self.line_not_connected)
        if not transect:
            return

        poly = self._trim_polygon(poly, transect)
        return poly

    @staticmethod
    def get_vertex(line_obj, index):
        """Return the point at index of the simplified line, or None."""
        coords = list(line_obj.sim_line.coords)
        # Normalize negative indices.
        if index < 0:
            index += len(coords)
        if 0 <= index < len(coords):
            return sh_geom.Point(coords[index])

    @staticmethod
    def get_neighbor(line_obj):
        """Get the neighbor coordinate of a line's end, based on vertex_index."""
        index = 0

        if line_obj.vertex_index == 0:
            index = 1
        elif line_obj.vertex_index == -1:
            index = -2

        return VertexNode.get_vertex(line_obj, index)

    @staticmethod
    def parallel_line_centered(p1, p2, center, length):
        """Generate a parallel line."""
        # Compute the direction vector.
        dx = p2.x - p1.x
        dy = p2.y - p1.y

        # Normalize the direction vector.
        magnitude = (dx**2 + dy**2) ** 0.5
        if magnitude == 0:
            return None
        dx /= magnitude
        dy /= magnitude

        # Compute half-length shifts.
        half_dx = (dx * length) / 2
        half_dy = (dy * length) / 2

        # Compute the endpoints of the new parallel line.
        new_p1 = sh_geom.Point(center.x - half_dx, center.y - half_dy)
        new_p2 = sh_geom.Point(center.x + half_dx, center.y + half_dy)

        return sh_geom.LineString([new_p1, new_p2])

    def get_transect_for_primary(self):
        """
        Get a transect line from two primary connected lines.

        This method calculates a transect line that is perpendicular to the line segment
        formed by the next vertex neighbors of these two lines and the current vertex.

        Return:
            A transect line object if the conditions are met, otherwise None.

        """
        if not self.line_connected or len(self.line_connected[0]) != 2:
            return None

        # Retrieve the two connected line objects from the first connectivity group.
        line_ids = self.line_connected[0]
        pt1 = None
        pt2 = None  # BUG FIX: was a duplicated `pt1 = None`
        if line_ids[0] == line_ids[1]:  # line ring
            # TODO: check line ring when merging vertex nodes.
            # TODO: change one end index to -1
            line_id = line_ids[0]
            pt1 = self.get_vertex(self.get_line_obj(line_id), 1)
            pt2 = self.get_vertex(self.get_line_obj(line_id), -2)
        else:  # two different lines
            line_obj1 = self.get_line_obj(line_ids[0])
            line_obj2 = self.get_line_obj(line_ids[1])

            pt1 = self.get_neighbor(line_obj1)
            pt2 = self.get_neighbor(line_obj2)

        if pt1 is None or pt2 is None:
            return None

        transect = algo_common.generate_perpendicular_line_precise([pt1, self.vertex, pt2], offset=40)
        return transect

    def get_transect_for_primary_second(self):
        """
        Get a transect line from the second primary connected line.

        For the second primary line, this method retrieves the neighbor point from
        two lines in the second connectivity group, creates a reference line through the
        vertex by mirroring the neighbor point about the vertex, and then generates a
        parallel line centered at the vertex.

        Returns:
            A LineString representing the transect if available, otherwise None.

        """
        # Ensure there is a second connectivity group.
        if not self.line_connected or len(self.line_connected) < 2:
            return None

        # Use the first line of the second connectivity group.
        second_primary = self.line_connected[1]
        line_obj1 = self.get_line_obj(second_primary[0])
        line_obj2 = self.get_line_obj(second_primary[1])
        if not line_obj1 or not line_obj2:
            return None

        pt1 = self.get_neighbor(line_obj1)
        pt2 = self.get_neighbor(line_obj2)

        if pt1 is None or pt2 is None:
            return None

        center = self.vertex
        transect = self.parallel_line_centered(pt1, pt2, center, TRANSECT_LENGTH)
        return transect

    def trim_primary_end(self, polys):
        """
        Trim first primary line in the vertex.

        Args:
        polys: list of polygons returned by sindex.query

        """
        if len(self.line_connected) == 0:
            return

        new_polys = []
        line = self.line_connected[0]

        # use the first connectivity group to derive the trimming transect
        transect = self.get_transect_for_primary()

        idx_1 = line[0]
        poly_1 = None
        # BUG FIX: was `idx_1 = line[1]`, which clobbered idx_1 and left
        # idx_2 unbound unless the elif branch below happened to run.
        idx_2 = line[1]
        poly_2 = None

        for idx, poly in polys.items():
            # TODO: no polygons
            if not poly:
                continue

            if poly.buffer(SMALL_BUFFER).contains(self.get_line_geom(line[0])):
                poly_1 = poly
                idx_1 = idx
            elif poly.buffer(SMALL_BUFFER).contains(self.get_line_geom(line[1])):
                poly_2 = poly
                idx_2 = idx

        if poly_1:
            poly_1 = self._trim_polygon(poly_1, transect)
            new_polys.append([idx_1, poly_1])
        if poly_2:
            poly_2 = self._trim_polygon(poly_2, transect)
            new_polys.append([idx_2, poly_2])

        return new_polys

    def trim_intersection(self, polys, merge_group=True):
        """
        Trim intersection of lines and polygons.

        TODO: there are polygons of zero area.

        """

        def get_poly_with_info(line, polys):
            # Return (index, polygon, max_width) of the first polygon whose
            # small buffer contains line, or (None, None, None).
            if polys.empty:
                return None, None, None

            for idx, row in polys.iterrows():
                poly = row.geometry
                if not poly:  # TODO: no polygon
                    continue

                if poly.buffer(SMALL_BUFFER).contains(line):
                    return idx, poly, row["max_width"]

            return None, None, None

        poly_trim_list = []
        primary_lines = []
        p_primary_list = []

        # retrieve primary lines
        if len(self.line_connected) > 0:
            for idx in self.line_connected[0]:  # only one connected line is used
                primary_lines.append(self.get_line(idx))
                _, poly, _ = get_poly_with_info(self.get_line(idx), polys)

                if poly:
                    p_primary_list.append(poly.buffer(bt_const.SMALL_BUFFER))
                else:
                    print("trim_intersection: No primary polygon found.")

        line_idx_to_trim = self.line_not_connected
        poly_list = []
        if not merge_group:  # add all remaining primary lines for trimming
            if len(self.line_connected) > 1:
                for line in self.line_connected[1:]:
                    line_idx_to_trim.extend(line)

            # sort line index to by footprint area
            for line_idx in line_idx_to_trim:
                line = self.get_line_geom(line_idx)
                poly_idx, poly, max_width = get_poly_with_info(line, polys)
                if poly_idx:
                    poly_list.append((line_idx, poly_idx, max_width))

            poly_list = sorted(poly_list, key=lambda x: x[2])

        # create PolygonTrimming object and trim all by primary line
        for i, indices in enumerate(poly_list):
            line_idx = indices[0]
            poly_idx = indices[1]
            line_cleanup = self.get_line(line_idx)
            poly_cleanup = polys.loc[poly_idx].geometry
            poly_trim = PolygonTrimming(
                line_index=line_idx,
                line_cleanup=line_cleanup,
                poly_index=poly_idx,
                poly_cleanup=poly_cleanup,
            )

            poly_trim_list.append(poly_trim)
            if p_primary_list:
                poly_trim.process(p_primary_list, self.vertex)

            # use poly_trim.poly_cleanup to update polys gdf's geometry
            polys.at[poly_trim.poly_index, "geometry"] = poly_trim.poly_cleanup

        # further trimming overlaps by non-primary lines
        # poly_list and poly_trim_list have same index
        for i, indices in enumerate(poly_list):
            p_list = []
            for p in poly_list[i + 1 :]:
                p_list.append(polys.loc[p[1]].geometry)

            poly_trim = poly_trim_list[i]
            poly_trim.process(p_list, self.vertex)

        return poly_trim_list

    def assign_vertex_class(self):
        """Classify the vertex by line count and number of primary groups."""
        if len(self.line_list) == 5:
            if len(self.line_connected) == 0:
                self.vertex_class = VertexClass.FIVE_WAY_ZERO_PRIMARY_LINE
            if len(self.line_connected) == 1:
                self.vertex_class = VertexClass.FIVE_WAY_ONE_PRIMARY_LINE
            if len(self.line_connected) == 2:
                self.vertex_class = VertexClass.FIVE_WAY_TWO_PRIMARY_LINE
        elif len(self.line_list) == 4:
            if len(self.line_connected) == 0:
                self.vertex_class = VertexClass.FOUR_WAY_ZERO_PRIMARY_LINE
            if len(self.line_connected) == 1:
                self.vertex_class = VertexClass.FOUR_WAY_ONE_PRIMARY_LINE
            if len(self.line_connected) == 2:
                self.vertex_class = VertexClass.FOUR_WAY_TWO_PRIMARY_LINE
        elif len(self.line_list) == 3:
            if len(self.line_connected) == 0:
                self.vertex_class = VertexClass.THREE_WAY_ZERO_PRIMARY_LINE
            if len(self.line_connected) == 1:
                self.vertex_class = VertexClass.THREE_WAY_ONE_PRIMARY_LINE
        elif len(self.line_list) == 2:
            if len(self.line_connected) == 0:
                self.vertex_class = VertexClass.TWO_WAY_ZERO_PRIMARY_LINE
        elif len(self.line_list) == 1:
            self.vertex_class = VertexClass.SINGLE_WAY

    def all_has_valid_group_attr(self):
        """If all values in group list are valid value, return True."""
        # TODO: if some line has no group, give advice
        for i in self.line_list:
            if i.group is None:
                return False

        return True

    def need_regrouping(self):
        """Placeholder; returns None (falsy) so regrouping is skipped."""
        pass

    def check_connectivity(self, use_angle_grouping=True):
        """
        Compute line connectivity and classify this vertex.

        Args:
        use_angle_grouping: pair lines by end-segment angle when True,
            otherwise pair lines sharing the same group attribute.

        """
        # Fill missing group with -1
        for line in self.line_list:
            if line.group is None:
                line.group = -1

        if self.need_regrouping():
            self.group_regroup()

        if use_angle_grouping:
            self.group_line_by_angle()
        else:
            self.update_connectivity_by_group()

        # record line not connected
        all_line_ids = self.get_all_line_ids()
        self.line_not_connected = list(all_line_ids - set(chain(*self.line_connected)))

        self.assign_vertex_class()

    def group_regroup(self):
        """Placeholder for future regrouping logic."""
        pass

    def update_connectivity_by_group(self):
        """Pair lines that share the same group attribute."""
        group_line = defaultdict(list)
        for i in self.line_list:
            group_line[i.group].append(i.line_id)

        for value in group_line.values():
            if len(value) > 1:
                self.line_connected.append(value)

    def group_line_by_angle(self):
        """Generate connectivity of all lines."""
        if len(self.line_list) == 1:
            return

        # if there are 2 and more lines
        new_angles = [i.get_angle_for_line() for i in self.line_list]
        angle_visited = [False] * len(new_angles)

        if len(self.line_list) == 2:
            angle_diff = abs(new_angles[0] - new_angles[1])
            angle_diff = angle_diff if angle_diff <= np.pi else angle_diff - np.pi

            # if angle_diff >= TURN_ANGLE_TOLERANCE:
            self.line_connected.append(
                (
                    self.line_list[0].line_id,
                    self.line_list[1].line_id,
                )
            )
            return

        # three and more lines: pair nearly-equal or nearly-opposite angles
        for i, angle_1 in enumerate(new_angles):
            for j, angle_2 in enumerate(new_angles[i + 1 :]):
                # i + j + 1 is the absolute index of angle_2 in new_angles
                if not angle_visited[i + j + 1]:
                    angle_diff = abs(angle_1 - angle_2)
                    angle_diff = angle_diff if angle_diff <= np.pi else angle_diff - np.pi
                    if (
                        angle_diff < ANGLE_TOLERANCE
                        or np.pi - ANGLE_TOLERANCE < abs(angle_1 - angle_2) < np.pi + ANGLE_TOLERANCE
                    ):
                        angle_visited[j + i + 1] = True  # mark partner as used
                        self.line_connected.append(
                            (
                                self.line_list[i].line_id,
                                self.line_list[i + j + 1].line_id,
                            )
                        )

add_line(line_class)

Add line when creating or merging other VertexNode.

Source code in beratools/core/algo_line_grouping.py
159
160
161
162
def add_line(self, line_class):
    """Register a line on this node and refresh the vertex coordinates."""
    self.line_list.append(line_class)
    # Keep the stored vertex in sync with the newly added line's end.
    self.set_vertex(line_class.line, line_class.vertex_index)

all_has_valid_group_attr()

If all values in group list are valid value, return True.

Source code in beratools/core/algo_line_grouping.py
530
531
532
533
534
535
536
537
def all_has_valid_group_attr(self):
    """Return True when every line at this vertex carries a group value."""
    # TODO: if some line has no group, give advice
    return all(line.group is not None for line in self.line_list)

get_transect_for_primary()

Get a transect line from two primary connected lines.

This method calculates a transect line that is perpendicular to the line segment formed by the next vertex neighbors of these two lines and the current vertex.

Return

A transect line object if the conditions are met, otherwise None.

Source code in beratools/core/algo_line_grouping.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def get_transect_for_primary(self):
    """
    Get a transect line from two primary connected lines.

    This method calculates a transect line that is perpendicular to the line segment
    formed by the next vertex neighbors of these two lines and the current vertex.

    Return:
        A transect line object if the conditions are met, otherwise None.

    """
    if not self.line_connected or len(self.line_connected[0]) != 2:
        return None

    # Retrieve the two connected line objects from the first connectivity group.
    line_ids = self.line_connected[0]
    pt1 = None
    pt2 = None  # BUG FIX: was a duplicated `pt1 = None`
    if line_ids[0] == line_ids[1]:  # line ring
        # TODO: check line ring when merging vertex nodes.
        # TODO: change one end index to -1
        line_id = line_ids[0]
        pt1 = self.get_vertex(self.get_line_obj(line_id), 1)
        pt2 = self.get_vertex(self.get_line_obj(line_id), -2)
    else:  # two different lines
        line_obj1 = self.get_line_obj(line_ids[0])
        line_obj2 = self.get_line_obj(line_ids[1])

        pt1 = self.get_neighbor(line_obj1)
        pt2 = self.get_neighbor(line_obj2)

    if pt1 is None or pt2 is None:
        return None

    transect = algo_common.generate_perpendicular_line_precise([pt1, self.vertex, pt2], offset=40)
    return transect

get_transect_for_primary_second()

Get a transect line from the second primary connected line.

For the second primary line, this method retrieves the neighbor point from two lines in the second connectivity group, creates a reference line through the vertex by mirroring the neighbor point about the vertex, and then generates a parallel line centered at the vertex.

Returns:

Type Description

A LineString representing the transect if available, otherwise None.

Source code in beratools/core/algo_line_grouping.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def get_transect_for_primary_second(self):
    """
    Get a transect line from the second primary connected line.

    Looks up the second connectivity group, takes the neighbor point of each
    of its two lines, and builds a line through the vertex parallel to the
    segment joining those neighbors.

    Returns:
        A LineString representing the transect if available, otherwise None.

    """
    # A second connectivity group must exist.
    if not self.line_connected or len(self.line_connected) < 2:
        return None

    second_group = self.line_connected[1]
    first_obj = self.get_line_obj(second_group[0])
    second_obj = self.get_line_obj(second_group[1])
    if not first_obj or not second_obj:
        return None

    neighbor_a = self.get_neighbor(first_obj)
    neighbor_b = self.get_neighbor(second_obj)
    if neighbor_a is None or neighbor_b is None:
        return None

    # Parallel line through the vertex, oriented like neighbor_a -> neighbor_b.
    return self.parallel_line_centered(
        neighbor_a, neighbor_b, self.vertex, TRANSECT_LENGTH
    )

group_line_by_angle()

Generate connectivity of all lines.

Source code in beratools/core/algo_line_grouping.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def group_line_by_angle(self):
    """
    Generate connectivity of all lines.

    Pairs lines whose end-segment angles are nearly equal or nearly opposite
    (within ANGLE_TOLERANCE), appending each pair of line ids to
    self.line_connected.
    """
    if len(self.line_list) == 1:
        return

    # if there are 2 and more lines
    new_angles = [i.get_angle_for_line() for i in self.line_list]
    angle_visited = [False] * len(new_angles)

    if len(self.line_list) == 2:
        angle_diff = abs(new_angles[0] - new_angles[1])
        angle_diff = angle_diff if angle_diff <= np.pi else angle_diff - np.pi

        # if angle_diff >= TURN_ANGLE_TOLERANCE:
        # NOTE(review): with the tolerance check commented out, two lines
        # at a vertex are always paired regardless of angle — confirm intended.
        self.line_connected.append(
            (
                self.line_list[0].line_id,
                self.line_list[1].line_id,
            )
        )
        return

    # three and more lines
    for i, angle_1 in enumerate(new_angles):
        for j, angle_2 in enumerate(new_angles[i + 1 :]):
            # i + j + 1 is the absolute index of angle_2 in new_angles
            if not angle_visited[i + j + 1]:
                angle_diff = abs(angle_1 - angle_2)
                angle_diff = angle_diff if angle_diff <= np.pi else angle_diff - np.pi
                if (
                    angle_diff < ANGLE_TOLERANCE
                    or np.pi - ANGLE_TOLERANCE < abs(angle_1 - angle_2) < np.pi + ANGLE_TOLERANCE
                ):
                    # mark the partner line as used so it is not re-paired
                    angle_visited[j + i + 1] = True  # tenth of PI
                    self.line_connected.append(
                        (
                            self.line_list[i].line_id,
                            self.line_list[i + j + 1].line_id,
                        )
                    )

merge(vertex)

Merge other VertexNode if they have same vertex coords.

Source code in beratools/core/algo_line_grouping.py
186
187
188
def merge(self, vertex):
    """Absorb another VertexNode that shares this node's vertex coordinates."""
    # The incoming node's first stored line is transferred onto this node.
    incoming_line = vertex.line_list[0]
    self.add_line(incoming_line)

parallel_line_centered(p1, p2, center, length) staticmethod

Generate a parallel line.

Source code in beratools/core/algo_line_grouping.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
@staticmethod
def parallel_line_centered(p1, p2, center, length):
    """Build a line of the given length through center, parallel to p1->p2."""
    # Direction of the reference segment p1 -> p2.
    vec_x = p2.x - p1.x
    vec_y = p2.y - p1.y

    norm = (vec_x**2 + vec_y**2) ** 0.5
    if norm == 0:
        # Degenerate reference segment: no direction to be parallel to.
        return None

    # Unit direction components.
    vec_x /= norm
    vec_y /= norm

    # Offsets of half the requested length along the direction.
    half_x = (vec_x * length) / 2
    half_y = (vec_y * length) / 2

    start = sh_geom.Point(center.x - half_x, center.y - half_y)
    end = sh_geom.Point(center.x + half_x, center.y + half_y)
    return sh_geom.LineString([start, end])

set_vertex(line, vertex_index)

Set vertex coordinates.

Source code in beratools/core/algo_line_grouping.py
155
156
157
def set_vertex(self, line, vertex_index):
    """Record the 2D coordinates of the given line end as this node's vertex."""
    end_point = shapely.get_point(line, vertex_index)
    self.vertex = shapely.force_2d(end_point)

trim_end_all(polys)

Trim all unconnected lines in the vertex.

Args: polys: list of polygons returned by sindex.query

Source code in beratools/core/algo_line_grouping.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def trim_end_all(self, polys):
    """
    Trim all unconnected lines in the vertex.

    Args:
    polys: list of polygons returned by sindex.query

    """
    geoms = polys.geometry
    trimmed_polys = []
    for poly_idx, geom in geoms.items():
        trimmed = self.trim_end(geom)
        # Keep only successful trims; trim_end returns None on failure.
        if trimmed:
            trimmed_polys.append((poly_idx, trimmed))

    return trimmed_polys

trim_intersection(polys, merge_group=True)

Trim intersection of lines and polygons.

TODO: there are polygons of zero area.

Source code in beratools/core/algo_line_grouping.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
def trim_intersection(self, polys, merge_group=True):
    """
    Trim intersection of lines and polygons.

    Args:
    polys: GeoDataFrame of footprint polygons with a "max_width" column;
        its geometry is updated in place during trimming.
    merge_group: when False, non-first primary groups are also trimmed.

    Returns:
    List of PolygonTrimming objects created for the trimmed lines.

    TODO: there are polygons of zero area.

    """

    def get_poly_with_info(line, polys):
        # Return (index, polygon, max_width) of the first polygon whose
        # small buffer contains line; (None, None, None) when not found.
        if polys.empty:
            return None, None, None

        for idx, row in polys.iterrows():
            poly = row.geometry
            if not poly:  # TODO: no polygon
                continue

            if poly.buffer(SMALL_BUFFER).contains(line):
                return idx, poly, row["max_width"]

        return None, None, None

    poly_trim_list = []
    primary_lines = []
    p_primary_list = []

    # retrieve primary lines
    if len(self.line_connected) > 0:
        for idx in self.line_connected[0]:  # only one connected line is used
            primary_lines.append(self.get_line(idx))
            _, poly, _ = get_poly_with_info(self.get_line(idx), polys)

            if poly:
                p_primary_list.append(poly.buffer(bt_const.SMALL_BUFFER))
            else:
                print("trim_intersection: No primary polygon found.")

    line_idx_to_trim = self.line_not_connected
    poly_list = []
    if not merge_group:  # add all remaining primary lines for trimming
        if len(self.line_connected) > 1:
            for line in self.line_connected[1:]:
                line_idx_to_trim.extend(line)

        # sort line index to by footprint area
        for line_idx in line_idx_to_trim:
            line = self.get_line_geom(line_idx)
            poly_idx, poly, max_width = get_poly_with_info(line, polys)
            if poly_idx:
                poly_list.append((line_idx, poly_idx, max_width))

        # narrower footprints (smaller max_width) are trimmed first
        poly_list = sorted(poly_list, key=lambda x: x[2])

    # create PolygonTrimming object and trim all by primary line
    for i, indices in enumerate(poly_list):
        line_idx = indices[0]
        poly_idx = indices[1]
        line_cleanup = self.get_line(line_idx)
        poly_cleanup = polys.loc[poly_idx].geometry
        poly_trim = PolygonTrimming(
            line_index=line_idx,
            line_cleanup=line_cleanup,
            poly_index=poly_idx,
            poly_cleanup=poly_cleanup,
        )

        poly_trim_list.append(poly_trim)
        if p_primary_list:
            poly_trim.process(p_primary_list, self.vertex)

        # use poly_trim.poly_cleanup to update polys gdf's geometry
        polys.at[poly_trim.poly_index, "geometry"] = poly_trim.poly_cleanup

    # further trimming overlaps by non-primary lines
    # poly_list and poly_trim_list have same index
    for i, indices in enumerate(poly_list):
        p_list = []
        for p in poly_list[i + 1 :]:
            p_list.append(polys.loc[p[1]].geometry)

        poly_trim = poly_trim_list[i]
        poly_trim.process(p_list, self.vertex)

    return poly_trim_list

trim_primary_end(polys)

Trim first primary line in the vertex.

Args: polys: list of polygons returned by sindex.query

Source code in beratools/core/algo_line_grouping.py
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
def trim_primary_end(self, polys):
    """
    Trim first primary line in the vertex.

    Args:
    polys: list of polygons returned by sindex.query

    Returns:
    List of [index, trimmed polygon] pairs for the polygons containing the
    two primary lines, or None when there is no primary connectivity group.

    """
    if len(self.line_connected) == 0:
        return

    new_polys = []
    line = self.line_connected[0]

    # use the first connectivity group to derive the trimming transect
    transect = self.get_transect_for_primary()

    idx_1 = line[0]
    poly_1 = None
    # BUG FIX: was `idx_1 = line[1]`, which clobbered idx_1 and left idx_2
    # unbound unless the elif branch below happened to run.
    idx_2 = line[1]
    poly_2 = None

    for idx, poly in polys.items():
        # TODO: no polygons
        if not poly:
            continue

        if poly.buffer(SMALL_BUFFER).contains(self.get_line_geom(line[0])):
            poly_1 = poly
            idx_1 = idx
        elif poly.buffer(SMALL_BUFFER).contains(self.get_line_geom(line[1])):
            poly_2 = poly
            idx_2 = idx

    if poly_1:
        poly_1 = self._trim_polygon(poly_1, transect)
        new_polys.append([idx_1, poly_1])
    if poly_2:
        poly_2 = self._trim_polygon(poly_2, transect)
        new_polys.append([idx_2, poly_2])

    return new_polys

get_angle(line, end_index)

Calculate the angle of the first or last segment.

Args: line: sh_geom.LineString end_index: 0 for the first segment or -1 for the last segment of the line's vertices. Multipart geometries should be taken into account.

Source code in beratools/core/algo_line_grouping.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def get_angle(line, end_index):
    """
    Calculate the angle of the first or last segment.

    Args:
    line: sh_geom.LineString
    end_index: 0 or -1 of the line vertices. Consider the multipart.

    Returns:
    Angle in radians (np.arctan2, range -pi..pi) of the end segment,
    measured from the end vertex toward its neighbor.

    Raises:
    ValueError: if end_index is not 0 or -1.

    """
    pts = points_in_line(line)

    if end_index == 0:
        pt_1 = pts[0]
        pt_2 = pts[1]
    elif end_index == -1:
        pt_1 = pts[-1]
        pt_2 = pts[-2]
    else:
        # BUG FIX: previously an invalid index fell through and raised a
        # confusing NameError on pt_1; fail fast with a clear message.
        raise ValueError("end_index must be 0 or -1")

    delta_x = pt_2.x - pt_1.x
    delta_y = pt_2.y - pt_1.y
    angle = np.arctan2(delta_y, delta_x)

    return angle

points_in_line(line)

Get point list of line.

Source code in beratools/core/algo_line_grouping.py
70
71
72
73
74
75
76
77
78
79
80
81
def points_in_line(line):
    """Get point list of line."""
    point_list = []
    try:
        # Walk every vertex of the line and collect it as a 2D Point.
        for coord in list(line.coords):
            if coord:
                point_list.append(sh_geom.Point(coord[0], coord[1]))
    except Exception as err:
        # Best-effort: report the failure and return what was gathered so far.
        print(err)

    return point_list

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

This file is intended to be hosting algorithms and utility functions/classes for merging lines.

MergeLines

Merge line segments in MultiLineString.

Source code in beratools/core/algo_merge_lines.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
class MergeLines:
    """
    Merge line segments in MultiLineString.

    Endpoints of the input segments are clustered into graph nodes using a
    1-unit buffer; each segment becomes a graph edge.  Connected components
    that form a simple path are merged into a single LineString; branching
    components are returned as their original segments.
    """

    def __init__(self, multi_line):
        """
        Args:
            multi_line: shapely MultiLineString whose segments will be merged.

        """
        self.G = None  # networkit endpoint graph (nodes=endpoints, edges=segments)
        self.line_segs = None  # cleaned list of LineString segments
        self.multi_line = multi_line
        self.node_poly = None  # per-node buffer polygon used to match endpoints
        self.end = None  # (first, last) coordinate of each segment

        self.create_graph()

    def create_graph(self):
        """Build the endpoint graph from the segments of self.multi_line."""
        self.line_segs = list(self.multi_line.geoms)

        # TODO: check empty line and null geoms
        # Drop degenerate (near zero-length) segments.
        self.line_segs = [line for line in self.line_segs if line.length > 1e-3]
        self.multi_line = sh_geom.MultiLineString(self.line_segs)
        m = sh_geom.mapping(self.multi_line)
        self.end = [(i[0], i[-1]) for i in m["coordinates"]]

        # Seed the graph with the two endpoints of the first segment.
        self.G = nk.Graph(edgesIndexed=True)
        self.G.addNodes(2)
        self.G.addEdge(0, 1)

        self.node_poly = [
            sh_geom.Point(self.end[0][0]).buffer(1),
            sh_geom.Point(self.end[0][1]).buffer(1),
        ]

        # Add the remaining segments, reusing an existing node whenever an
        # endpoint falls inside that node's buffer polygon.
        for i, line in enumerate(self.end[1:]):
            node_exists = False
            pt = sh_geom.Point(line[0])
            pt_buffer = pt.buffer(1)

            for node in self.G.iterNodes():
                if self.node_poly[node].contains(pt):
                    node_exists = True
                    node_start = node
            if not node_exists:
                node_start = self.G.addNode()
                self.node_poly.append(pt_buffer)

            node_exists = False
            pt = sh_geom.Point(line[1])
            pt_buffer = pt.buffer(1)
            for node in self.G.iterNodes():
                if self.node_poly[node].contains(pt):
                    node_exists = True
                    node_end = node
            if not node_exists:
                node_end = self.G.addNode()
                self.node_poly.append(pt_buffer)

            self.G.addEdge(node_start, node_end)

    def get_components(self):
        """Return connected components of the endpoint graph as node lists."""
        cc = nk.components.ConnectedComponents(self.G)
        cc.run()
        return cc.getComponents()

    def is_single_path(self, component):
        """Return True if no node in *component* has more than two neighbors."""
        for node in component:
            if len(list(self.G.iterNeighbors(node))) > 2:
                return False  # branching node: component is not a simple path

        return True

    def get_merged_line_for_component(self, component):
        """
        Return merged line(s) for one component.

        Simple-path components are merged into one LineString; branching
        components fall back to their original segments.  Returns None when
        the component has no mergeable edges.
        """
        sub = nk.graphtools.subgraphFromNodes(self.G, component)
        lines = None
        if nk.graphtools.maxDegree(sub) >= 3:  # not simple path: keep originals
            edges = [self.G.edgeId(i[0], i[1]) for i in list(sub.iterEdges())]
            # NOTE(review): itemgetter returns a bare LineString (not a tuple)
            # when there is only one edge — TODO confirm callers handle this.
            lines = itemgetter(*edges)(self.line_segs)
        elif nk.graphtools.maxDegree(sub) == 2:
            lines = self.merge_single_line(component)

        return lines

    def find_path_for_component(self, component):
        """Order the nodes of a simple-path component from one end to the other."""
        neighbors = list(self.G.iterNeighbors(component[0]))
        path = [component[0]]
        right = neighbors[0]
        path.append(right)

        left = None
        if len(neighbors) == 2:
            left = neighbors[1]
            path.insert(0, left)

        # Walk rightwards until an endpoint (degree-1 node) is reached.
        neighbors = list(self.G.iterNeighbors(right))
        while len(neighbors) > 1:
            if neighbors[0] not in path:
                path.append(neighbors[0])
                right = neighbors[0]
            else:
                path.append(neighbors[1])
                right = neighbors[1]

            neighbors = list(self.G.iterNeighbors(right))

        # last node
        if neighbors[0] not in path:
            path.append(neighbors[0])

        # process left side
        if left:
            neighbors = list(self.G.iterNeighbors(left))
            while len(neighbors) > 1:
                if neighbors[0] not in path:
                    path.insert(0, neighbors[0])
                    left = neighbors[0]
                else:
                    path.insert(0, neighbors[1])
                    left = neighbors[1]

                neighbors = list(self.G.iterNeighbors(left))

            # last node
            if neighbors[0] not in path:
                path.insert(0, neighbors[0])

        return path

    def merge_single_line(self, component):
        """Merge a simple-path component into a single LineString (in a list)."""
        path = self.find_path_for_component(component)

        pairs = list(pairwise(path))
        line_list = [self.G.edgeId(i[0], i[1]) for i in pairs]

        vertices = []

        for i, id in enumerate(line_list):
            pair = pairs[i]
            poly_t = self.node_poly[pair[0]]
            point_t = sh_geom.Point(self.end[id][0])
            # Orient each segment so it runs in path direction.
            if poly_t.contains(point_t):
                line = self.line_segs[id]
            else:
                line = self.line_segs[id].reverse()

            vertices.extend(list(line.coords))
            # Drop the segment's last vertex to avoid duplicating the shared
            # joint with the next segment; it is re-appended after the loop.
            last_vertex = vertices.pop()

        vertices.append(last_vertex)
        merged_line = sh_geom.LineString(vertices)

        return [merged_line]

    def merge_all_lines(self):
        """Merge every component; return a (Multi)LineString or None."""
        components = self.get_components()
        lines = []
        for c in components:
            line = self.get_merged_line_for_component(c)
            if line:
                # Bug fix: reuse the computed result instead of calling
                # get_merged_line_for_component(c) a second time.
                lines.extend(line)
            else:  # TODO: check line
                print(f"merge_all_lines: failed to merge: {self.multi_line.bounds}")

        if len(lines) > 1:
            return sh_geom.MultiLineString(lines)
        elif len(lines) == 1:
            return lines[0]
        else:
            return None

Split lines at intersections using a class-based approach.

LineSplitter

Split lines at intersections.

Source code in beratools/core/algo_split_with_lines.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
class LineSplitter:
    """Split lines at intersections."""

    def __init__(self, line_gdf):
        """
        Initialize the LineSplitter with a GeoDataFrame of line geometries.

        Args:
        line_gdf (GeoDataFrame): input lines to be split at their
            mutual intersections.

        """
        # Explode if needed for multi-part geometries
        self.line_gdf = line_gdf.explode()
        self.line_gdf[INTER_STATUS_COL] = 1  # record line intersection status
        self.inter_status = {}  # row index -> 0 for lines flagged as suspect
        self.sindex = self.line_gdf.sindex  # Spatial index for faster operations

        # NOTE(review): starts as a plain list; replaced by a GeoDataFrame in
        # find_intersections() or process() before it is used.
        self.intersection_gdf = []
        self.split_lines_gdf = None

    def cut_line_by_points(self, line, points):
        """
        Cuts a LineString into segments based on the given points.

        Args:
        line: A shapely LineString to be cut.
        points: A list of Point objects where the LineString needs to be cut.
            Assumed non-empty; callers only invoke this with cut points.

        Return:
        A list of LineString segments after the cuts.

        """
        # Create a spatial index for the coordinates of the LineString
        line_coords = [Point(x, y) for x, y in line.coords]
        sindex = STRtree(line_coords)

        # Sort points based on their projected position along the line
        sorted_points = sorted(points, key=lambda p: line.project(p))
        segments = []

        # Process each point, inserting it into the correct location
        start_idx = 0
        start_pt = None
        end_pt = None

        for point in sorted_points:
            # Find the closest segment on the line using the spatial index
            # (assumes shapely 2.x STRtree.nearest returns an integer index
            # — TODO confirm)
            nearest_pt_idx = sindex.nearest(point)
            end_idx = nearest_pt_idx
            end_pt = point

            # Decide whether the cut point lies before or after the nearest
            # vertex along the line, to pick the right slice boundary.
            dist1 = line.project(point)
            dist2 = line.project(line_coords[nearest_pt_idx])

            if dist1 > dist2:
                end_idx = nearest_pt_idx + 1

            # Create a new segment
            new_coords = line_coords[start_idx:end_idx]
            if start_pt:  # Append start point
                new_coords = [start_pt] + new_coords

            if end_pt:  # Append end point
                new_coords = new_coords + [end_pt]

            nearest_segment = LineString(new_coords)
            start_idx = end_idx
            start_pt = end_pt

            segments.append(nearest_segment)

        # Add remaining part of the line after the last point
        if start_idx < len(line_coords):
            # If last point is not close to end point of line
            # NOTE(review): start_pt is None when points is empty — assumes
            # callers always pass at least one cut point; TODO confirm
            if start_pt.distance(line_coords[-1]) > EPSILON:
                remaining_part = LineString([start_pt] + line_coords[end_idx:])
                segments.append(remaining_part)

        return segments

    def find_intersections(self):
        """
        Find intersections between lines in the GeoDataFrame.

        Populates self.intersection_gdf with the intersection points and
        records suspect lines (overlapping or closely clustered
        intersections) in self.inter_status.  Returns None.

        """
        visited_pairs = set()
        intersection_points = []

        # Iterate through each line geometry to find intersections
        for idx, line1 in enumerate(self.line_gdf.geometry):
            # Use spatial index to find candidates for intersection
            indices = list(self.sindex.intersection(line1.bounds))
            indices.remove(idx)  # Remove the current index from the list

            for match_idx in indices:
                line2 = self.line_gdf.iloc[match_idx].geometry

                # Create an index pair where the smaller index comes first
                pair = tuple(sorted([idx, match_idx]))

                # Skip if this pair has already been visited
                if pair in visited_pairs:
                    continue

                # Mark the pair as visited
                visited_pairs.add(pair)

                # Only check lines that are different and intersect.
                # Snap line1 to line2 so near-misses within EPSILON count.
                line1 = snap(line1, line2, tolerance=EPSILON)
                if line1.intersects(line2):
                    # Find intersection points (can be multiple)
                    intersections = line1.intersection(line2)

                    if intersections.is_empty:
                        continue

                    # Intersection can be Point, MultiPoint, LineString
                    # or GeometryCollection
                    if isinstance(intersections, Point):
                        intersection_points.append(intersections)
                    else:
                        # record for further inspection
                        # GeometryCollection, MultiLineString
                        if isinstance(intersections, MultiPoint):
                            intersection_points.extend(intersections.geoms)
                        elif isinstance(intersections, LineString):
                            intersection_points.append(intersections.interpolate(0.5, normalized=True))

                        # if minimum distance between points is greater than threshold
                        # mark line as valid
                        if isinstance(intersections, MultiPoint):
                            if min_distance_in_multipoint(intersections) > algo_common.DISTANCE_THRESHOLD:
                                continue
                        # if intersection is a line, mark line as valid
                        if isinstance(intersections, LineString):
                            continue

                        # Remaining cases are suspect; flag both lines.
                        for item in pair:
                            self.inter_status[item] = 0

        self.intersection_gdf = gpd.GeoDataFrame(geometry=intersection_points, crs=self.line_gdf.crs)

    def split_lines_at_intersections(self):
        """
        Split lines at the intersection points in self.intersection_gdf.

        Stores the result (after cleanup) in self.split_lines_gdf.

        """
        # Create a spatial index for faster point-line intersection checks
        sindex = self.intersection_gdf.sindex

        # List to hold the new split line segments
        new_rows = []

        # Iterate through each intersection point to split lines at that point
        for row in self.line_gdf.itertuples():
            if not isinstance(row.geometry, LineString):
                continue

            # Use spatial index to find possible line candidates for intersection
            possible_matches = sindex.query(row.geometry.buffer(EPSILON))
            end_pts = MultiPoint([row.geometry.coords[0], row.geometry.coords[-1]])

            pt_list = []
            new_segments = [row.geometry]

            for idx in possible_matches:
                point = self.intersection_gdf.loc[idx].geometry
                # Check if the point is on the line
                if row.geometry.distance(point) < EPSILON:
                    # Skip points at either end of the line: no split needed.
                    if end_pts.distance(point) < EPSILON:
                        continue
                    else:
                        pt_list.append(point)

            if len(pt_list) > 0:
                # Split the line at the intersection
                new_segments = self.cut_line_by_points(row.geometry, pt_list)

            # If the line was split into multiple segments, create new rows
            for segment in new_segments:
                new_row = row._asdict()  # Convert the original row into a dictionary
                new_row["geometry"] = segment  # Update the geometry with the split one
                new_rows.append(new_row)

        self.split_lines_gdf = gpd.GeoDataFrame(
            new_rows, columns=self.line_gdf.columns, crs=self.line_gdf.crs
        )

        self.split_lines_gdf = algo_common.clean_line_geometries(self.split_lines_gdf)

        # Debugging: print how many segments were created
        print(f"Total new line segments created: {len(new_rows)}")

    def save_to_geopackage(
        self,
        input_gpkg,
        line_layer="split_lines",
        intersection_layer=None,
        invalid_layer=None,
    ):
        """
        Save the split lines and intersection points to the GeoPackage.

        Args:
        input_gpkg: path of the GeoPackage file to write to.
        line_layer: split lines layer name in the GeoPackage.
        intersection_layer: layer name for intersection points in the GeoPackage.
        invalid_layer: layer name for lines flagged as invalid splits.

        """
        # Save intersection points and split lines to the GeoPackage
        if self.split_lines_gdf is not None and intersection_layer:
            if len(self.intersection_gdf) > 0:
                self.intersection_gdf.to_file(input_gpkg, layer=intersection_layer, driver="GPKG")

        if self.split_lines_gdf is not None and line_layer:
            if len(self.split_lines_gdf) > 0:
                self.split_lines_gdf["length"] = self.split_lines_gdf.geometry.length
                self.split_lines_gdf.to_file(input_gpkg, layer=line_layer, driver="GPKG")

        # save invalid splits
        invalid_splits = self.line_gdf.loc[self.line_gdf[INTER_STATUS_COL] == 0]
        if not invalid_splits.empty and invalid_layer:
            if len(invalid_splits) > 0:
                invalid_splits.to_file(input_gpkg, layer=invalid_layer, driver="GPKG")

    def process(self, intersection_gdf=None):
        """
        Find intersection points, split lines at intersections.

        Args:
        intersection_gdf: external GeoDataFrame with intersection points.
            When None, intersections are computed from the input lines.

        """
        if intersection_gdf is not None:
            self.intersection_gdf = intersection_gdf
        else:
            self.find_intersections()

        # Propagate the suspect flags recorded during intersection detection.
        if self.inter_status:
            for idx in self.inter_status.keys():
                self.line_gdf.loc[idx, INTER_STATUS_COL] = self.inter_status[idx]

        if not self.intersection_gdf.empty:
            # Split the lines at intersection points
            self.split_lines_at_intersections()
        else:
            print("No intersection points found, no lines to split.")

__init__(line_gdf)

Initialize the LineSplitter with the input GeoPackage and layer name.

Args: input_gpkg (str): Path to the input GeoPackage file. layer_name (str): Name of the layer to read from the GeoPackage.

Source code in beratools/core/algo_split_with_lines.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def __init__(self, line_gdf):
    """
    Initialize the LineSplitter with a GeoDataFrame of line geometries.

    Args:
    line_gdf (GeoDataFrame): input lines to be split at their
        mutual intersections.

    """
    # Explode if needed for multi-part geometries
    self.line_gdf = line_gdf.explode()
    self.line_gdf[INTER_STATUS_COL] = 1  # record line intersection status
    self.inter_status = {}  # row index -> 0 for lines flagged as suspect
    self.sindex = self.line_gdf.sindex  # Spatial index for faster operations

    # NOTE(review): starts as a plain list; replaced by a GeoDataFrame in
    # find_intersections() or process() before it is used.
    self.intersection_gdf = []
    self.split_lines_gdf = None

cut_line_by_points(line, points)

Cuts a LineString into segments based on the given points.

Args: line: A shapely LineString to be cut. points: A list of Point objects where the LineString needs to be cut.

Return: A list of LineString segments after the cuts.

Source code in beratools/core/algo_split_with_lines.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def cut_line_by_points(self, line, points):
    """
    Cuts a LineString into segments based on the given points.

    Args:
    line: A shapely LineString to be cut.
    points: A list of Point objects where the LineString needs to be cut.
        Assumed non-empty; callers only invoke this with cut points.

    Return:
    A list of LineString segments after the cuts.

    """
    # Create a spatial index for the coordinates of the LineString
    line_coords = [Point(x, y) for x, y in line.coords]
    sindex = STRtree(line_coords)

    # Sort points based on their projected position along the line
    sorted_points = sorted(points, key=lambda p: line.project(p))
    segments = []

    # Process each point, inserting it into the correct location
    start_idx = 0
    start_pt = None
    end_pt = None

    for point in sorted_points:
        # Find the closest segment on the line using the spatial index
        # (assumes shapely 2.x STRtree.nearest returns an integer index
        # — TODO confirm)
        nearest_pt_idx = sindex.nearest(point)
        end_idx = nearest_pt_idx
        end_pt = point

        # Decide whether the cut point lies before or after the nearest
        # vertex along the line, to pick the right slice boundary.
        dist1 = line.project(point)
        dist2 = line.project(line_coords[nearest_pt_idx])

        if dist1 > dist2:
            end_idx = nearest_pt_idx + 1

        # Create a new segment
        new_coords = line_coords[start_idx:end_idx]
        if start_pt:  # Append start point
            new_coords = [start_pt] + new_coords

        if end_pt:  # Append end point
            new_coords = new_coords + [end_pt]

        nearest_segment = LineString(new_coords)
        start_idx = end_idx
        start_pt = end_pt

        segments.append(nearest_segment)

    # Add remaining part of the line after the last point
    if start_idx < len(line_coords):
        # If last point is not close to end point of line
        # NOTE(review): start_pt is None when points is empty — assumes
        # callers always pass at least one cut point; TODO confirm
        if start_pt.distance(line_coords[-1]) > EPSILON:
            remaining_part = LineString([start_pt] + line_coords[end_idx:])
            segments.append(remaining_part)

    return segments

find_intersections()

Find intersections between lines in the GeoDataFrame.

Return: List of Point geometries where the lines intersect.

Source code in beratools/core/algo_split_with_lines.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def find_intersections(self):
    """
    Find intersections between lines in the GeoDataFrame.

    Populates self.intersection_gdf with the intersection points and
    records suspect lines (overlapping or closely clustered
    intersections) in self.inter_status.  Returns None.

    """
    visited_pairs = set()
    intersection_points = []

    # Iterate through each line geometry to find intersections
    for idx, line1 in enumerate(self.line_gdf.geometry):
        # Use spatial index to find candidates for intersection
        indices = list(self.sindex.intersection(line1.bounds))
        indices.remove(idx)  # Remove the current index from the list

        for match_idx in indices:
            line2 = self.line_gdf.iloc[match_idx].geometry

            # Create an index pair where the smaller index comes first
            pair = tuple(sorted([idx, match_idx]))

            # Skip if this pair has already been visited
            if pair in visited_pairs:
                continue

            # Mark the pair as visited
            visited_pairs.add(pair)

            # Only check lines that are different and intersect.
            # Snap line1 to line2 so near-misses within EPSILON count.
            line1 = snap(line1, line2, tolerance=EPSILON)
            if line1.intersects(line2):
                # Find intersection points (can be multiple)
                intersections = line1.intersection(line2)

                if intersections.is_empty:
                    continue

                # Intersection can be Point, MultiPoint, LineString
                # or GeometryCollection
                if isinstance(intersections, Point):
                    intersection_points.append(intersections)
                else:
                    # record for further inspection
                    # GeometryCollection, MultiLineString
                    if isinstance(intersections, MultiPoint):
                        intersection_points.extend(intersections.geoms)
                    elif isinstance(intersections, LineString):
                        intersection_points.append(intersections.interpolate(0.5, normalized=True))

                    # if minimum distance between points is greater than threshold
                    # mark line as valid
                    if isinstance(intersections, MultiPoint):
                        if min_distance_in_multipoint(intersections) > algo_common.DISTANCE_THRESHOLD:
                            continue
                    # if intersection is a line, mark line as valid
                    if isinstance(intersections, LineString):
                        continue

                    # Remaining cases are suspect; flag both lines.
                    for item in pair:
                        self.inter_status[item] = 0

    self.intersection_gdf = gpd.GeoDataFrame(geometry=intersection_points, crs=self.line_gdf.crs)

process(intersection_gdf=None)

Find intersection points, split lines at intersections.

Args: intersection_gdf: external GeoDataFrame with intersection points.

Source code in beratools/core/algo_split_with_lines.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def process(self, intersection_gdf=None):
    """
    Find intersection points, split lines at intersections.

    Args:
    intersection_gdf: external GeoDataFrame with intersection points.
        When None, intersections are computed from the input lines.

    """
    if intersection_gdf is not None:
        self.intersection_gdf = intersection_gdf
    else:
        self.find_intersections()

    # Propagate the suspect flags recorded during intersection detection.
    if self.inter_status:
        for idx in self.inter_status.keys():
            self.line_gdf.loc[idx, INTER_STATUS_COL] = self.inter_status[idx]

    if not self.intersection_gdf.empty:
        # Split the lines at intersection points
        self.split_lines_at_intersections()
    else:
        print("No intersection points found, no lines to split.")

save_to_geopackage(input_gpkg, line_layer='split_lines', intersection_layer=None, invalid_layer=None)

Save the split lines and intersection points to the GeoPackage.

Args: line_layer: split lines layer name in the GeoPackage. intersection_layer: layer name for intersection points in the GeoPackage.

Source code in beratools/core/algo_split_with_lines.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def save_to_geopackage(
    self,
    input_gpkg,
    line_layer="split_lines",
    intersection_layer=None,
    invalid_layer=None,
):
    """
    Save the split lines and intersection points to the GeoPackage.

    Args:
    input_gpkg: path of the GeoPackage file to write to.
    line_layer: split lines layer name in the GeoPackage.
    intersection_layer: layer name for intersection points in the GeoPackage.
    invalid_layer: layer name for lines flagged as invalid splits.

    """
    # Save intersection points and split lines to the GeoPackage
    if self.split_lines_gdf is not None and intersection_layer:
        if len(self.intersection_gdf) > 0:
            self.intersection_gdf.to_file(input_gpkg, layer=intersection_layer, driver="GPKG")

    if self.split_lines_gdf is not None and line_layer:
        if len(self.split_lines_gdf) > 0:
            self.split_lines_gdf["length"] = self.split_lines_gdf.geometry.length
            self.split_lines_gdf.to_file(input_gpkg, layer=line_layer, driver="GPKG")

    # save invalid splits
    invalid_splits = self.line_gdf.loc[self.line_gdf[INTER_STATUS_COL] == 0]
    if not invalid_splits.empty and invalid_layer:
        if len(invalid_splits) > 0:
            invalid_splits.to_file(input_gpkg, layer=invalid_layer, driver="GPKG")

split_lines_at_intersections()

Split lines at the given intersection points.

Args: intersection_points: List of Point geometries where the lines should be split.

Returns: A GeoDataFrame with the split lines.

Source code in beratools/core/algo_split_with_lines.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def split_lines_at_intersections(self):
    """
    Split lines at the intersection points in self.intersection_gdf.

    Stores the result (after cleanup) in self.split_lines_gdf.

    """
    # Create a spatial index for faster point-line intersection checks
    sindex = self.intersection_gdf.sindex

    # List to hold the new split line segments
    new_rows = []

    # Iterate through each intersection point to split lines at that point
    for row in self.line_gdf.itertuples():
        if not isinstance(row.geometry, LineString):
            continue

        # Use spatial index to find possible line candidates for intersection
        possible_matches = sindex.query(row.geometry.buffer(EPSILON))
        end_pts = MultiPoint([row.geometry.coords[0], row.geometry.coords[-1]])

        pt_list = []
        new_segments = [row.geometry]

        for idx in possible_matches:
            point = self.intersection_gdf.loc[idx].geometry
            # Check if the point is on the line
            if row.geometry.distance(point) < EPSILON:
                # Skip points at either end of the line: no split needed.
                if end_pts.distance(point) < EPSILON:
                    continue
                else:
                    pt_list.append(point)

        if len(pt_list) > 0:
            # Split the line at the intersection
            new_segments = self.cut_line_by_points(row.geometry, pt_list)

        # If the line was split into multiple segments, create new rows
        for segment in new_segments:
            new_row = row._asdict()  # Convert the original row into a dictionary
            new_row["geometry"] = segment  # Update the geometry with the split one
            new_rows.append(new_row)

    self.split_lines_gdf = gpd.GeoDataFrame(
        new_rows, columns=self.line_gdf.columns, crs=self.line_gdf.crs
    )

    self.split_lines_gdf = algo_common.clean_line_geometries(self.split_lines_gdf)

    # Debugging: print how many segments were created
    print(f"Total new line segments created: {len(new_rows)}")

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to move line vertices onto the correct seismic line courses for improved alignment and analysis in geospatial data processing.

VertexGrouping

A class used to group vertices and perform vertex optimization.

Source code in beratools/core/algo_vertex_optimization.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
class VertexGrouping:
    """
    Group line end vertices and perform vertex optimization.

    Input lines are split into segments, segment end points are clustered
    into vertex groups using a spatial index, and each group is optimized
    against the cost raster (optionally in parallel).
    """

    def __init__(
        self,
        in_line,
        in_raster,
        search_distance,
        line_radius,
        out_line,
        processes,
        verbose,
        in_layer=None,
        out_layer=None,
    ):
        """
        Store tool inputs and precompute the cost raster footprint.

        Args:
            in_line: Path to the input line vector file.
            in_raster: Path to the cost raster.
            search_distance: Distance to search around line end points.
            line_radius: Buffer radius around lines used for cost clipping.
            out_line: Path of the output line file.
            processes: Number of parallel worker processes.
            verbose: Whether to print progress information.
            in_layer: Optional layer name in the input file.
            out_layer: Optional layer name in the output file.

        """
        self.in_line = in_line
        self.in_raster = in_raster
        self.line_radius = float(line_radius)
        self.search_distance = float(search_distance)
        self.out_line = out_line
        self.processes = processes
        self.verbose = verbose
        self.parallel_mode = bt_const.PARALLEL_MODE
        self.in_layer = in_layer
        self.out_layer = out_layer

        self.crs = None
        self.vertex_grp = []  # _Vertex groups awaiting optimization
        self.sindex = None  # STRtree spatial index over line geometries

        self.line_list = []  # one single-row GeoDataFrame per line segment
        self.line_visited = None  # per line: {0: bool, -1: bool} end flags

        # calculate cost raster footprint
        self.cost_footprint = algo_common.generate_raster_footprint(self.in_raster, latlon=False)

    def set_parallel_mode(self, parallel_mode):
        """Override the default parallel execution mode."""
        self.parallel_mode = parallel_mode

    def create_vertex_group(self, line_obj):
        """
        Create a new vertex group.

        Args:
            line_obj : _SingleLine

        """
        # all end points not added will stay with this vertex
        vertex = line_obj.get_end_vertex()
        vertex_obj = _Vertex(line_obj)
        search = self.sindex.query(vertex.buffer(bt_const.SMALL_BUFFER))

        # add more vertices to the new group
        for i in search:
            line = self.line_list[i]
            if i == line_obj.line_no:
                continue

            # check both ends (0 and -1) of each nearby candidate line
            for end_no in (0, -1):
                if not self.line_visited[i][end_no]:
                    new_line = _SingleLine(line, i, end_no, self.search_distance)
                    if new_line.touches_point(vertex):
                        vertex_obj.add_line(new_line)
                        self.line_visited[i][end_no] = True

        vertex_obj.in_raster = self.in_raster

        vertex_obj.line_radius = self.line_radius
        vertex_obj.cost_footprint = self.cost_footprint
        self.vertex_grp.append(vertex_obj)

    def create_all_vertex_groups(self):
        """Build a vertex group for every unvisited line end point."""
        self.line_list = algo_common.prepare_lines_gdf(self.in_line, layer=self.in_layer, proc_segments=True)
        self.sindex = STRtree([item.geometry[0] for item in self.line_list])
        self.line_visited = [{0: False, -1: False} for _ in range(len(self.line_list))]

        for line_no in range(len(self.line_list)):
            # visit both end points (index 0 and -1) of every segment
            for end_no in (0, -1):
                if self.line_visited[line_no][end_no]:
                    continue

                line = _SingleLine(self.line_list[line_no], line_no, end_no, self.search_distance)

                if not line.is_valid:
                    # report via the loop index: subscripting the
                    # _SingleLine with 'line_no' raised TypeError before
                    print(f"Line {line_no} is invalid")
                    continue

                self.create_vertex_group(line)
                self.line_visited[line_no][end_no] = True

    def update_all_lines(self):
        """Move line end points to each group's optimized vertex."""
        for vertex_obj in self.vertex_grp:
            # optimization check is invariant per group; hoist out of the
            # per-line loop (original re-tested it for every line)
            if not vertex_obj.vertex_opt:
                continue

            for line in vertex_obj.lines:
                old_line = self.line_list[line.line_no].geometry[0]
                self.line_list[line.line_no].geometry = [
                    update_line_end_pt(old_line, line.end_no, vertex_obj.vertex_opt)
                ]

    def save_all_layers(self, line_file):
        """
        Save optimized lines plus auxiliary layers (paths/anchors/vertices).

        For shapefile output, auxiliary layers go to a sibling GeoPackage
        ("<stem>_aux.gpkg") since shapefiles hold a single layer.

        Args:
            line_file: Destination path for the line layer.

        """
        line_file = Path(line_file)
        lines = pd.concat(self.line_list)
        lines.to_file(line_file, layer=self.out_layer)

        aux_file = line_file
        if line_file.suffix == ".shp":
            file_stem = line_file.stem
            aux_file = line_file.with_stem(file_stem + "_aux").with_suffix(".gpkg")

        lc_paths = []
        anchors = []
        vertices = []
        for item in self.vertex_grp:
            if item.centerlines:
                lc_paths.extend(item.centerlines)
            if item.anchors:
                anchors.extend(item.anchors)
            if item.vertex_opt:
                vertices.append(item.vertex_opt)

        # drop None placeholders before building the GeoDataFrames
        lc_paths = [item for item in lc_paths if item is not None]
        anchors = [item for item in anchors if item is not None]
        vertices = [item for item in vertices if item is not None]

        lc_paths = gpd.GeoDataFrame(geometry=lc_paths, crs=lines.crs)
        anchors = gpd.GeoDataFrame(geometry=anchors, crs=lines.crs)
        vertices = gpd.GeoDataFrame(geometry=vertices, crs=lines.crs)

        lc_paths.to_file(aux_file, layer="lc_paths")
        anchors.to_file(aux_file, layer="anchors")
        vertices.to_file(aux_file, layer="vertices")

    def compute(self):
        """Optimize all vertex groups, honoring the configured parallelism."""
        vertex_grp = bt_base.execute_multiprocessing(
            algo_common.process_single_item,
            self.vertex_grp,
            "Vertex Optimization",
            self.processes,
            1,
            verbose=self.verbose,
        )

        self.vertex_grp = vertex_grp

create_vertex_group(line_obj)

Create a new vertex group.

Parameters:

Name Type Description Default
line_obj

_SingleLine

required
Source code in beratools/core/algo_vertex_optimization.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def create_vertex_group(self, line_obj):
    """
    Create a new vertex group.

    Args:
        line_obj : _SingleLine

    """
    # end points that are never claimed elsewhere remain with this vertex
    end_pt = line_obj.get_end_vertex()
    group = _Vertex(line_obj)
    candidates = self.sindex.query(end_pt.buffer(bt_const.SMALL_BUFFER))

    # pull nearby, still-unvisited line ends into the new group
    for idx in candidates:
        if idx == line_obj.line_no:
            continue

        candidate = self.line_list[idx]
        for end_no in (0, -1):
            if self.line_visited[idx][end_no]:
                continue

            other = _SingleLine(candidate, idx, end_no, self.search_distance)
            if other.touches_point(end_pt):
                group.add_line(other)
                self.line_visited[idx][end_no] = True

    group.in_raster = self.in_raster

    group.line_radius = self.line_radius
    group.cost_footprint = self.cost_footprint
    self.vertex_grp.append(group)

Constants

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to provide common constants.

CenterlineFlags

Bases: Flag

Flags for the centerline algorithm.

Source code in beratools/core/constants.py
36
37
38
39
40
41
class CenterlineFlags(enum.Flag):
    """Flags for the centerline algorithm."""

    # NOTE(review): these Flag members use bool values, so DELETE_HOLES and
    # SIMPLIFY_POLYGON (both True == 1) become aliases of one member, and
    # USE_SKIMAGE_GRAPH (False == 0) is the empty flag. Confirm this is
    # intended — independent module-level booleans may express it better.
    USE_SKIMAGE_GRAPH = False
    DELETE_HOLES = True
    SIMPLIFY_POLYGON = True

ParallelMode

Bases: IntEnum

Defines the parallel mode for the algorithms.

Source code in beratools/core/constants.py
44
45
46
47
48
49
50
51
52
@enum.unique
class ParallelMode(enum.IntEnum):
    """Defines the parallel mode for the algorithms."""

    # auto() numbers members 1..5, matching the original explicit values
    SEQUENTIAL = enum.auto()       # 1: single process, no parallelism
    MULTIPROCESSING = enum.auto()  # 2
    CONCURRENT = enum.auto()       # 3
    DASK = enum.auto()             # 4
    SLURM = enum.auto()            # 5

Logger

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to provide logger functions.

Logger

Bases: object

Logger class to handle logging in the BERA Tools application.

This class sets up a logger that outputs to both the console and a file. It allows for different logging levels for console and file outputs. It also provides a method to print messages directly to the logger.

Source code in beratools/core/logger.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class Logger(object):
    """
    Logger class to handle logging in the BERA Tools application.

    This class sets up a logger that outputs to both the console and a file.
    It allows for different logging levels for console and file outputs.
    It also provides a method to print messages directly to the logger.
    """

    def __init__(self, name, console_level=logging.INFO, file_level=logging.INFO):
        """
        Create a named logger and attach console/file handlers via setup_logger.

        Args:
            name: Logger name; also used to derive the log file name.
            console_level: Logging level for the stdout handler.
            file_level: Logging level for the rotating file handler.

        """
        self.logger = logging.getLogger(name)
        self.name = name
        self.console_level = console_level
        self.file_level = file_level

        self.setup_logger()

    def get_logger(self):
        """Return the underlying logging.Logger instance."""
        return self.logger

    def print(self, msg, flush=True):
        """
        Re-define print in logging.

        Args:
            msg: Message to log at INFO level.
            flush: If True, flush every handler attached to this logger.

        """
        self.logger.info(msg)
        if flush:
            for handler in self.logger.handlers:
                handler.flush()

    def setup_logger(self):
        """
        Attach console and rotating-file handlers to the ROOT logger.

        NOTE(review): handlers are added to the root logger, not self.logger,
        so constructing multiple Logger instances attaches duplicate handlers
        and repeats every message — confirm this is intended.
        """
        # Change root logger level from WARNING (default) to NOTSET
        # in order for all messages to be delegated.
        logging.getLogger().setLevel(logging.NOTSET)
        # bt.get_logger_file_name is provided elsewhere in the package
        log_file = bt.get_logger_file_name(self.name)

        # Add stdout handler, with level INFO
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(self.console_level)
        formatter = logging.Formatter("%(message)s")
        console_handler.setFormatter(formatter)
        logging.getLogger().addHandler(console_handler)

        # Add file rotating handler, 5MB size limit, 5 backups
        rotating_handler = logging.handlers.RotatingFileHandler(
            filename=log_file, maxBytes=5 * 1000 * 1000, backupCount=5
        )

        rotating_handler.setLevel(self.file_level)
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        rotating_handler.setFormatter(formatter)
        logging.getLogger().addHandler(rotating_handler)
        # suppress records whose message starts with "parsing"
        logging.getLogger().addFilter(NoParsingFilter())

print(msg, flush=True)

Re-define print in logging.

Args: msg — the message to log at INFO level; flush — whether to flush all log handlers after emitting.

Source code in beratools/core/logger.py
56
57
58
59
60
61
62
63
64
65
66
67
68
def print(self, msg, flush=True):
    """
    Re-define print in logging.

    Args:
        msg: Message to log at INFO level.
        flush: If True, flush every handler attached to the logger.

    """
    self.logger.info(msg)
    if not flush:
        return
    for handler in self.logger.handlers:
        handler.flush()

NoParsingFilter

Bases: Filter

Filter to exclude log messages that start with "parsing".

This is useful to avoid cluttering the log with parsing-related messages.

Source code in beratools/core/logger.py
25
26
27
28
29
30
31
32
33
class NoParsingFilter(logging.Filter):
    """
    Filter to exclude log messages that start with "parsing".

    This is useful to avoid cluttering the log with parsing-related messages.
    """

    def filter(self, record):
        """Return True if the record should be emitted (i.e. kept)."""
        message = record.getMessage()
        return not message.startswith("parsing")

Tool Base

Copyright (C) 2025 Applied Geospatial Research Group.

This script is licensed under the GNU General Public License v3.0. See https://gnu.org/licenses/gpl-3.0 for full license details.

Author: Richard Zeng

Description

This script is part of the BERA Tools. Webpage: https://github.com/appliedgrg/beratools

The purpose of this script is to provide fundamental utilities for tools.

ToolBase

Bases: object

Base class for tools.

Source code in beratools/core/tool_base.py
36
37
38
39
40
41
42
43
class ToolBase(object):
    """Base class for tools."""

    def __init__(self):
        """Initialize the tool; subclasses may extend this."""

    def execute_multiprocessing(self):
        """Placeholder execution hook for subclasses to override; returns None."""