diff --git a/src/imcflibs/imagej/objects3d.py b/src/imcflibs/imagej/objects3d.py index ce514c7..1c58cab 100644 --- a/src/imcflibs/imagej/objects3d.py +++ b/src/imcflibs/imagej/objects3d.py @@ -8,6 +8,7 @@ from de.mpicbg.scf.imgtools.image.create.image import ImageCreationUtilities from de.mpicbg.scf.imgtools.image.create.labelmap import WatershedLabeling from ij import IJ +from inra.ijpb.plugins import RemoveBorderLabelsPlugin from mcib3d.geom import Objects3DPopulation from mcib3d.image3d import ImageHandler, ImageLabeller from mcib3d.image3d.processing import MaximaFinder @@ -71,7 +72,15 @@ def imgplus_to_population3d(imp): return Objects3DPopulation(img) -def segment_3d_image(imp, title=None, min_thresh=1, min_vol=None, max_vol=None): +def segment_3d_image( + imp, + title=None, + min_thresh=1, + min_vol=None, + max_vol=None, + remove_touching_borders=False, + remove_touching_borders_z=False, +): """Segment a 3D binary image to get a labelled stack. Parameters @@ -90,6 +99,11 @@ def segment_3d_image(imp, title=None, min_thresh=1, min_vol=None, max_vol=None): max_vol : int, optional Maximum volume (in voxels) above which objects get filtered. Defaults to None. + remove_touching_borders : bool, optional + Whether to remove objects that touch the borders in X and Y. Defaults to False. + remove_touching_borders_z : bool, optional + Whether to remove objects that touch the z-axis borders. Defaults to False. + Returns ------- @@ -107,14 +121,24 @@ def segment_3d_image(imp, title=None, min_thresh=1, min_vol=None, max_vol=None): labeler.setMinSizeCalibrated(min_vol, img) if max_vol: labeler.setMaxSizeCalibrated(max_vol, img) - # Generate labelled segmentation seg = labeler.getLabels(img) seg.setScale(cal.pixelWidth, cal.pixelDepth, cal.getUnits()) + + seg = RemoveBorderLabelsPlugin().remove( + seg.getImagePlus(), + remove_touching_borders, + remove_touching_borders, + remove_touching_borders, + remove_touching_borders, + remove_touching_borders_z, + remove_touching_borders_z, + ) + if title: seg.setTitle(title) - return seg.getImagePlus() + return seg def maxima_finder_3d(imp, min_threshold=0, noise=100, rxy=1.5, rz=1.5):