
  1# Copyright 2022 vkit-x Administrator. All Rights Reserved.
  3# This project (vkit-x/vkit) is dual-licensed under commercial and SSPL licenses.
  5# The commercial license gives you the full rights to create and distribute software
  6# on your own terms without any SSPL license obligations. For more information,
  7# please see the "LICENSE_COMMERCIAL.txt" file.
  9# This project is also available under Server Side Public License (SSPL).
 10# The SSPL licensing is ideal for use cases such as open source projects with
 11# SSPL distribution, student/academic purposes, hobby projects, internal research
 12# projects without external distribution, or other projects where all SSPL
 13# obligations can be met. For more information, please see the "LICENSE_SSPL.txt" file.
 14from typing import Sequence, List, Dict, Optional
 15import bisect
 16import heapq
 18import attrs
 19from numpy.random import Generator as RandomGenerator
 20import numpy as np
 21import cv2 as cv
 22import iolite as io
 24from vkit.utility import rng_choice, read_json_file
 25from vkit.element import Image, ImageMode, Mask
 26from vkit.mechanism.distortion import rotate
 27from ..interface import (
 28    Engine,
 29    EngineExecutorFactory,
 30    NoneTypeEngineInitResource,
 32from .type import ImageEngineRunConfig
 36class ImageMeta:
 37    image_file: str
 38    grayscale_mean: float
 39    grayscale_std: float
 42class FolderTree:
 43    IMAGE = 'image'
 44    METAS_JSON = 'metas.json'
 47def load_image_metas_from_folder(folder: str):
 48    in_fd = io.folder(folder, expandvars=True, exists=True)
 49    image_fd = io.folder(
 50        in_fd / FolderTree.IMAGE,
 51        exists=True,
 52    )
 53    metas_json = io.file(
 54        in_fd / FolderTree.METAS_JSON,
 55        exists=True,
 56    )
 58    image_metas: List[ImageMeta] = []
 59    for meta in read_json_file(metas_json):
 60        image_file = io.file(image_fd / meta['image_file'], exists=True)
 61        image_metas.append(
 62            ImageMeta(
 63                image_file=str(image_file),
 64                grayscale_mean=meta['grayscale_mean'],
 65                grayscale_std=meta['grayscale_std'],
 66            )
 67        )
 69    return image_metas
 73class ImageCombinerEngineInitConfig:
 74    image_meta_folder: str
 75    target_image_mode: ImageMode = ImageMode.RGB
 76    enable_cache: bool = False
 77    prob_use_only_the_anchor_image: float = 0.7
 78    prob_rotate_image: float = 0.5
 79    sigma: float = 3.0
 80    init_segment_width_min_ratio: float = 0.25
 81    gaussian_blur_kernel_size = 5
 85class PrioritizedSegment:
 86    y: int = attrs.field(order=True)
 87    left: int = attrs.field(order=False)
 88    right: int = attrs.field(order=False)
 91class ImageCombinerEngine(
 92    Engine[
 93        ImageCombinerEngineInitConfig,
 94        NoneTypeEngineInitResource,
 95        ImageEngineRunConfig,
 96        Image,
 97    ]
 98):  # yapf: disable
100    @classmethod
101    def get_type_name(cls) -> str:
102        return 'combiner'
104    def __init__(
105        self,
106        init_config: ImageCombinerEngineInitConfig,
107        init_resource: Optional[NoneTypeEngineInitResource] = None,
108    ):
109        super().__init__(init_config, init_resource)
111        self.image_metas = load_image_metas_from_folder(init_config.image_meta_folder)
112        self.image_metas = sorted(
113            self.image_metas,
114            key=lambda meta: meta.grayscale_mean,
115        )
116        self.image_metas_grayscale_means = [
117            image_meta.grayscale_mean for image_meta in self.image_metas
118        ]
119        self.enable_cache = init_config.enable_cache
120        self.image_file_to_cache_image: Dict[str, Image] = {}
122    def sample_image_metas_based_on_random_anchor(
123        self,
124        run_config: ImageEngineRunConfig,
125        rng: RandomGenerator,
126    ):
127        # Get candidates based on anchor.
128        anchor_image_meta = rng_choice(rng, self.image_metas)
130        if rng.random() < self.init_config.prob_use_only_the_anchor_image:
131            return [anchor_image_meta]
133        else:
134            grayscale_std = anchor_image_meta.grayscale_std
135            grayscale_mean = anchor_image_meta.grayscale_mean
137            grayscale_begin = round(grayscale_mean - self.init_config.sigma * grayscale_std)
138            grayscale_end = round(grayscale_mean + self.init_config.sigma * grayscale_std)
140            index_begin = bisect.bisect_left(self.image_metas_grayscale_means, x=grayscale_begin)
141            index_end = bisect.bisect_right(self.image_metas_grayscale_means, x=grayscale_end)
142            image_metas = self.image_metas[index_begin:index_end]
143            assert image_metas
144            return image_metas
146    @classmethod
147    def fill_np_edge_mask(
148        cls,
149        np_edge_mask: np.ndarray,
150        height: int,
151        width: int,
152        gaussian_blur_half_kernel_size: int,
153        up: int,
154        down: int,
155        left: int,
156        right: int,
157    ):
158        # Fill up.
159        up_min = max(0, up - gaussian_blur_half_kernel_size)
160        up_max = min(height - 1, up + gaussian_blur_half_kernel_size)
161        np_edge_mask[up_min:up_max + 1, left:right + 1] = 1
163        # Fill down.
164        down_min = max(0, down - gaussian_blur_half_kernel_size)
165        down_max = min(height - 1, down + gaussian_blur_half_kernel_size)
166        np_edge_mask[down_min:down_max + 1, left:right + 1] = 1
168        # Fill left.
169        left_min = max(0, left - gaussian_blur_half_kernel_size)
170        left_max = min(width - 1, left + gaussian_blur_half_kernel_size)
171        np_edge_mask[up:down + 1, left_min:left_max + 1] = 1
173        # Fill right.
174        right_min = max(0, right - gaussian_blur_half_kernel_size)
175        right_max = min(width - 1, right + gaussian_blur_half_kernel_size)
176        np_edge_mask[up:down + 1, right_min:right_max + 1] = 1
178    def synthesize_image(
179        self,
180        run_config: ImageEngineRunConfig,
181        image_metas: Sequence[ImageMeta],
182        rng: RandomGenerator,
183    ):
184        height = run_config.height
185        width = run_config.width
187        mat = np.zeros((height, width, 3), dtype=np.uint8)
188        edge_mask = Mask.from_shape((height, width))
189        gaussian_blur_half_kernel_size = self.init_config.gaussian_blur_kernel_size // 2 + 1
191        # Initialize segments.
192        priority_queue: List[PrioritizedSegment] = []
193        segment_width_min = int(
194            np.clip(
195                round(self.init_config.init_segment_width_min_ratio * width),
196                1,
197                width - 1,
198            )
199        )
200        left = 0
201        while left + segment_width_min - 1 < width:
202            right = rng.integers(
203                left + segment_width_min - 1,
204                width,
205            )
206            if right + 1 - left < segment_width_min or width - right - 1 < segment_width_min:
207                break
208            priority_queue.append(PrioritizedSegment(
209                y=0,
210                left=left,
211                right=right,
212            ))
213            left = right + 1
214        if left < width:
215            priority_queue.append(PrioritizedSegment(
216                y=0,
217                left=left,
218                right=width - 1,
219            ))
221        # For random rotation.
222        image_file_to_rotate_flag: Dict[str, bool] = {}
224        while priority_queue:
225            # Pop a segment
226            cur_segment = heapq.heappop(priority_queue)
228            # Deal with connection.
229            segments: List[PrioritizedSegment] = []
230            while priority_queue and priority_queue[0].y == cur_segment.y:
231                segments.append(heapq.heappop(priority_queue))
233            if segments:
234                segments.append(cur_segment)
235                segments = sorted(segments, key=lambda segment: segment.left)
236                cur_segment_idx = -1
237                for segment_idx, segment in enumerate(segments):
238                    if segment.left == cur_segment.left and segment.right == cur_segment.right:
239                        cur_segment_idx = segment_idx
240                        break
241                assert cur_segment_idx >= 0
243                begin = cur_segment_idx
244                while begin > 0 and segments[begin - 1].right + 1 == segments[begin].left:
245                    begin -= 1
246                end = cur_segment_idx
247                while end + 1 < len(segments) and segments[end].right + 1 == segments[end + 1].left:
248                    end += 1
250                if begin < end:
251                    # Update the current segment.
252                    cur_segment.left = segments[begin].left
253                    cur_segment.right = segments[end].right
255                # Push back.
256                for segment in segments[:begin]:
257                    heapq.heappush(priority_queue, segment)
258                for segment in segments[end + 1:]:
259                    heapq.heappush(priority_queue, segment)
261            # Load image.
262            image_meta = rng_choice(rng, image_metas)
264            if self.enable_cache and image_meta.image_file in self.image_file_to_cache_image:
265                segment_image = self.image_file_to_cache_image[image_meta.image_file]
267            else:
268                segment_image = Image.from_file(image_meta.image_file).to_target_mode_image(
269                    self.init_config.target_image_mode
270                )
272                if image_meta.image_file not in image_file_to_rotate_flag:
273                    rotate_flag = (rng.random() < self.init_config.prob_rotate_image)
274                    image_file_to_rotate_flag[image_meta.image_file] = rotate_flag
276                if image_file_to_rotate_flag[image_meta.image_file]:
277                    segment_image = rotate.distort_image(
278                        {'angle': 90},
279                        image=segment_image,
280                    )
282                if self.enable_cache:
283                    self.image_file_to_cache_image[image_meta.image_file] = segment_image
285            # Fill image and edge mask.
286            up = cur_segment.y
287            down = min(height - 1, up + segment_image.height - 1)
288            left = cur_segment.left
289            right = min(cur_segment.right, left + segment_image.width - 1)
290            mat[up:down + 1, left:right + 1] = \
291                segment_image.mat[:down + 1 - up, :right + 1 - left]
293            with edge_mask.writable_context:
294                self.fill_np_edge_mask(
295                    np_edge_mask=edge_mask.mat,
296                    height=height,
297                    width=width,
298                    gaussian_blur_half_kernel_size=gaussian_blur_half_kernel_size,
299                    up=up,
300                    down=down,
301                    left=left,
302                    right=right,
303                )
305            # Update segments.
306            if right == cur_segment.right:
307                # Reach the current right end.
308                cur_segment.y = down + 1
309                if cur_segment.y < height:
310                    heapq.heappush(priority_queue, cur_segment)
311            else:
312                # Not reaching the right end.
313                assert right < cur_segment.right
314                new_segment = PrioritizedSegment(
315                    y=down + 1,
316                    left=left,
317                    right=right,
318                )
319                if new_segment.y < height:
320                    heapq.heappush(priority_queue, new_segment)
322                cur_segment.left = right + 1
323                heapq.heappush(priority_queue, cur_segment)
325        # Apply gaussian blur.
326        gaussian_blur_sigma = gaussian_blur_half_kernel_size / 3
327        gaussian_blur_ksize = (self.init_config.gaussian_blur_kernel_size,) * 2
328        edge_mask.fill_np_array(
329            mat,
330            cv.GaussianBlur(mat, gaussian_blur_ksize, gaussian_blur_sigma),
331        )
333        return Image(mat=mat)
335    def run(
336        self,
337        run_config: ImageEngineRunConfig,
338        rng: Optional[RandomGenerator] = None,
339    ) -> Image:
340        assert rng is not None
342        assert not run_config.disable_resizing
343        image_metas = self.sample_image_metas_based_on_random_anchor(run_config, rng)
344        return self.synthesize_image(run_config, image_metas, rng)
# Module-level executor factory for ImageCombinerEngine (EngineExecutorFactory
# comes from ..interface).
image_combiner_engine_executor_factory = EngineExecutorFactory(ImageCombinerEngine)
