From 86b38ce38b4790759922cad02bd4d5d56e86d2b6 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Wed, 23 Mar 2022 15:08:11 +0100
Subject: [PATCH 01/17] ENH: add TFRecords class + to_tfrecords method + split
 "array" data from "scalar" data + add max_nb_of_samples for Dataset

---
 python/otbtf.py  | 241 +++++++++++++++++++++++++++++++++++++++++++++--
 python/system.py |  68 +++++++++++++
 2 files changed, 301 insertions(+), 8 deletions(-)
 create mode 100644 python/system.py

diff --git a/python/otbtf.py b/python/otbtf.py
index a23d5237..45e13791 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -18,17 +18,23 @@
 #
 # ==========================================================================*/
 """
-Contains stuff to help working with TensorFlow and geospatial data in the
-OTBTF framework.
+Contains stuff to help working with TensorFlow and geospatial data in the OTBTF framework.
 """
+import glob
+import json
+import os
 import threading
 import multiprocessing
 import time
 import logging
 from abc import ABC, abstractmethod
+from functools import partial
+from tqdm import tqdm
+
 import numpy as np
 import tensorflow as tf
 import gdal
+import system
 
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
@@ -167,13 +173,18 @@ class PatchesImagesReader(PatchesReaderBase):
     :see PatchesReaderBase
     """
 
-    def __init__(self, filenames_dict: dict, use_streaming=False):
+    def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
              src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif],
              ...
              src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]}
+        :param scalar_dict: (optional) a dict containing lists of scalars (int, float, str) as follows:
+            {scalar_name1: ["value_1", ..., "value_N"],
+             scalar_name2: [value_1, ..., value_N],
+             ...
+             scalar_nameN: [value1, ..., value_N]}
         :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory.
         """
 
@@ -182,8 +193,13 @@ class PatchesImagesReader(PatchesReaderBase):
         # gdal_ds dict
         self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()}
 
+        # Scalar parameters (e.g. metadatas)
+        self.scalar_dict = scalar_dict
+        if scalar_dict is None:
+            self.scalar_dict = {}
+
         # check number of patches in each sources
-        if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1:
+        if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1:
             raise Exception("Each source must have the same number of patches images")
 
         # streaming on/off
@@ -213,6 +229,12 @@ class PatchesImagesReader(PatchesReaderBase):
         if not self.use_streaming:
             patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds}
             self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds}
+            # Create a scalars dict so that one scalar <-> one patch
+            self.scalar_buffer = {}
+            for src_key, scalars in self.scalar_dict.items():
+                self.scalar_buffer[src_key] = []
+                for scalar, ds_size in zip(scalars, self.ds_sizes):
+                    self.scalar_buffer[src_key].extend([scalar] * ds_size)
 
     def _get_ds_and_offset_from_index(self, index):
         offset = index
@@ -254,9 +276,11 @@ class PatchesImagesReader(PatchesReaderBase):
 
         if not self.use_streaming:
             res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds}
+            res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()})
         else:
             i, offset = self._get_ds_and_offset_from_index(index)
             res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds}
+            res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()})
 
         return res
 
@@ -362,16 +386,18 @@ class Dataset:
     """
 
     def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128,
-                 Iterator: IteratorBase = RandomIterator):
+                 Iterator=RandomIterator, max_nb_of_samples=None):
         """
         :param patches_reader: The patches reader instance
         :param buffer_length: The number of samples that are stored in the buffer
         :param Iterator: The iterator class used to generate the sequence of patches indices.
+        :param max_nb_of_samples: Optional, max number of samples to consider
         """
 
         # patches reader
         self.patches_reader = patches_reader
-        self.size = self.patches_reader.get_size()
+        self.size = min(self.patches_reader.get_size(),
+                        max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size()
 
         # iterator
         self.iterator = Iterator(patches_reader=self.patches_reader)
@@ -380,6 +406,7 @@ class Dataset:
         self.output_types = dict()
         self.output_shapes = dict()
         one_sample = self.patches_reader.get_sample(index=0)
+        print(one_sample)  # DEBUG
         for src_key, np_arr in one_sample.items():
             self.output_shapes[src_key] = np_arr.shape
             self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype)
@@ -404,6 +431,14 @@ class Dataset:
                                                          output_types=self.output_types,
                                                          output_shapes=self.output_shapes).repeat(1)
 
+    def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True):
+        """
+
+        """
+        tfrecord = TFRecords(output_dir)
+        tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder)
+
+
     def get_stats(self) -> dict:
         """
         :return: the dataset statistics, computed by the patches reader
@@ -502,8 +537,8 @@ class DatasetFromPatchesImages(Dataset):
     :see Dataset
     """
 
-    def __init__(self, filenames_dict: dict, use_streaming: bool = False, buffer_length: int = 128,
-                 Iterator: IteratorBase = RandomIterator):
+    def __init__(self, filenames_dict, use_streaming=False, buffer_length: int = 128,
+                 Iterator=RandomIterator):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image1, ..., src1_patches_imageN1],
@@ -518,3 +553,193 @@ class DatasetFromPatchesImages(Dataset):
         patches_reader = PatchesImagesReader(filenames_dict=filenames_dict, use_streaming=use_streaming)
 
         super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator)
+
+
+class TFRecords:
+    """
+    This class allows converting Dataset objects to TFRecords and loading them as TensorFlow datasets.
+    """
+
+    def __init__(self, path):
+        """
+        :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path
+        """
+        if system.is_dir(path) or not os.path.exists(path):
+            self.dirpath = path
+            system.mkdir(self.dirpath)
+            self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath))
+        else:
+            self.dirpath = system.dirname(path)
+            self.tfrecords_pattern_path = path
+        self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath))
+        self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath))
+        self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
+        self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
+
+    def _bytes_feature(self, value):
+        """
+        Used to convert a value to a type compatible with tf.train.Example.
+        :param value: value
+        :return a bytes_list from a string / byte.
+        """
+        if isinstance(value, type(tf.constant(0))):
+            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+    def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True):
+        """
+        Convert and save samples from dataset object to tfrecord files.
+        :param dataset: Dataset object to convert into a set of tfrecords
+        :param n_samples_per_shard: Number of samples per shard
+        :param drop_remainder: Whether additional samples should be dropped. Advisable when using multi-worker training.
+                               If True, all TFRecords will have `n_samples_per_shard` samples
+        """
+        logging.info("%s samples", dataset.size)
+
+        nb_shards = (dataset.size // n_samples_per_shard)
+        if not drop_remainder and dataset.size % n_samples_per_shard > 0:
+            nb_shards += 1
+
+        self.convert_dataset_output_shapes(dataset)
+
+        def _convert_data(data):
+            """
+            Convert the output types to their string names so they can be serialized to JSON
+            """
+            data_converted = {}
+
+            for k, d in data.items():
+                data_converted[k] = d.name
+
+            return data_converted
+
+        self.save(_convert_data(dataset.output_types), self.output_types_file)
+
+        for i in tqdm(range(nb_shards)):
+
+            if (i + 1) * n_samples_per_shard <= dataset.size:
+                nb_sample = n_samples_per_shard
+            else:
+                nb_sample = dataset.size - i * n_samples_per_shard
+
+            filepath = "{}{}.records".format(system.pathify(self.dirpath), i)
+            with tf.io.TFRecordWriter(filepath) as writer:
+                for s in range(nb_sample):
+                    sample = dataset.read_one_sample()
+                    serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
+                    features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
+                                serialized_sample.items()}
+                    tf_features = tf.train.Features(feature=features)
+                    example = tf.train.Example(features=tf_features)
+                    writer.write(example.SerializeToString())
+
+    @staticmethod
+    def save(data, filepath):
+        """
+        Save data to JSON format.
+        :param data: Data to save in JSON format
+        :param filepath: Output file name
+        """
+
+        with open(filepath, 'w') as f:
+            json.dump(data, f, indent=4)
+
+    @staticmethod
+    def load(filepath):
+        """
+        Return data read from JSON format.
+        :param filepath: Input file name
+        """
+        with open(filepath, 'r') as f:
+            return json.load(f)
+
+    def convert_dataset_output_shapes(self, dataset):
+        """
+        Convert and save numpy shape to tensorflow shape.
+        :param dataset: Dataset object containing output shapes
+        """
+        output_shapes = {}
+
+        for key in dataset.output_shapes.keys():
+            output_shapes[key] = (None,) + dataset.output_shapes[key]
+
+        self.save(output_shapes, self.output_shape_file)
+
+    @staticmethod
+    def parse_tfrecord(example, features_types, target_keys):
+        """
+        Parse example object to sample dict.
+        :param example: Example object to parse
+        :param features_types: List of types for each feature
+        :param target_keys: list of keys of the targets
+        """
+        read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
+        example_parsed = tf.io.parse_single_example(example, read_features)
+
+        for key in read_features.keys():
+            example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key])
+
+        # Differentiating inputs and outputs
+        input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
+        target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
+
+        return input_parsed, target_parsed
+
+
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+        """
+        Read all tfrecord files matching with pattern and convert data to tensorflow dataset.
+        :param batch_size: Size of tensorflow batch
+        :param target_key: Key of the target, e.g. 's2_out'
+        :param n_workers: number of workers, e.g. 4 if using 4 GPUs
+                                             e.g. 12 if using 3 nodes of 4 GPUs
+        :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
+                               `batch_size` elements. True is advisable when training on multiworkers.
+                               False is advisable when evaluating metrics so that all samples are used
+        :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size
+                                    elements are shuffled using uniform random.
+        """
+        options = tf.data.Options()
+        if shuffle_buffer_size:
+            options.experimental_deterministic = False  # disable order, increase speed
+        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+
+        # TODO: to be investigated :
+        # 1/ num_parallel_reads useful ? I/O bottleneck of not ?
+        # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ?
+        # 3/ shuffle or not shuffle ?
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path)
+        logging.info('Number of matching TFRecords: %s', len(matching_files))
+        matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)]  # files multiple of workers
+        nb_matching_files = len(matching_files)
+        if nb_matching_files == 0:
+            raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord "
+                            "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path))
+        logging.info('Reducing number of records to : %s', nb_matching_files)
+        dataset = tf.data.TFRecordDataset(matching_files)  # , num_parallel_reads=2)  # interleaves reads from xxx files
+        dataset = dataset.with_options(options)  # uses data as soon as it streams in, rather than in its original order
+        dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        if shuffle_buffer_size:
+            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+        # TODO: check whether prefetch should be placed before batch, cf. https://keras.io/examples/keras_recipes/tfrecord/
+
+        return dataset
+
+    def read_one_sample(self, target_keys):
+        """
+        Read one TFRecord file matching the pattern and convert its data to a tensorflow dataset.
+        :param target_keys: Keys of the targets, e.g. ['s2_out']
+        """
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        one_file = matching_files[0]
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+        dataset = tf.data.TFRecordDataset(one_file)
+        dataset = dataset.map(parse)
+        dataset = dataset.batch(1)
+
+        sample = iter(dataset).get_next()
+        return sample
diff --git a/python/system.py b/python/system.py
new file mode 100644
index 00000000..e7b581fd
--- /dev/null
+++ b/python/system.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) 2020-2022 INRAE
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+"""Various system operations"""
+import logging
+import pathlib
+import os
+
+# ---------------------------------------------------- Helpers ---------------------------------------------------------
+
+def pathify(pth):
+    """ Adds posix separator if needed """
+    if not pth.endswith("/"):
+        pth += "/"
+    return pth
+
+
+def mkdir(pth):
+    """ Create a directory """
+    path = pathlib.Path(pth)
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def dirname(filename):
+    """ Returns the parent directory of the file """
+    return str(pathlib.Path(filename).parent)
+
+
+def basic_logging_init():
+    """ basic logging initialization """
+    logging.basicConfig(
+        format='%(asctime)s %(levelname)-8s %(message)s',
+        level=logging.INFO,
+        datefmt='%Y-%m-%d %H:%M:%S')
+
+
+def logging_info(msg, verbose=True):
+    """
+    Prints log info only if required by `verbose`
+    :param msg: message to log
+    :param verbose: boolean. Whether to log msg or not. Default True
+    :return:
+    """
+    if verbose:
+        logging.info(msg)
+
+def is_dir(filename):
+    """ return True if filename is the path to a directory """
+    return os.path.isdir(filename)
-- 
GitLab
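
A minimal end-to-end sketch of the API added in this patch (the file names, source keys and scalar values are hypothetical; the classes are assumed importable from python/otbtf.py):

    from otbtf import PatchesImagesReader, Dataset, TFRecords

    # One scalar (e.g. the acquisition month) is attached to each patches image of each source
    reader = PatchesImagesReader(
        filenames_dict={"s2": ["s2_patches_1.tif", "s2_patches_2.tif"],
                        "labels": ["labels_patches_1.tif", "labels_patches_2.tif"]},
        scalar_dict={"month": [4, 7]})
    dataset = Dataset(patches_reader=reader, max_nb_of_samples=1000)

    # Convert the samples to TFRecord shards, then read them back as a tf.data.Dataset
    dataset.to_tfrecords("/tmp/tfrecords", n_samples_per_shard=100)
    tf_dataset = TFRecords("/tmp/tfrecords").read(batch_size=8, target_keys=["labels"])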


From 28e489468ece5134fd441df1609a8ded77a95531 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Thu, 24 Mar 2022 16:43:46 +0100
Subject: [PATCH 02/17] ENH: split the Dataset initialisation and the
 patches_reader feeding

---
 python/otbtf.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 45e13791..8f8f9359 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -385,7 +385,7 @@ class Dataset:
     :see Buffer
     """
 
-    def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128,
+    def __init__(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128,
                  Iterator=RandomIterator, max_nb_of_samples=None):
         """
         :param patches_reader: The patches reader instance
@@ -393,8 +393,19 @@ class Dataset:
         :param Iterator: The iterator class used to generate the sequence of patches indices.
         :param max_nb_of_samples: Optional, max number of samples to consider
         """
-
         # patches reader
+        if patches_reader:
+            self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples)
+
+
+    def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128,
+                 Iterator=RandomIterator, max_nb_of_samples=None):
+        """
+        :param patches_reader: The patches reader instance
+        :param buffer_length: The number of samples that are stored in the buffer
+        :param Iterator: The iterator class used to generate the sequence of patches indices.
+        :param max_nb_of_samples: Optional, max number of samples to consider
+        """
         self.patches_reader = patches_reader
         self.size = min(self.patches_reader.get_size(),
                         max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size()
-- 
GitLab
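
A short sketch of the deferred feeding enabled by this split (the file name is hypothetical): a Dataset can now be created empty and receive its patches reader later through feed().

    from otbtf import Dataset, PatchesImagesReader

    dataset = Dataset()  # no patches reader attached yet
    reader = PatchesImagesReader(filenames_dict={"s2": ["s2_patches_1.tif"]})
    dataset.feed(patches_reader=reader, buffer_length=128)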


From c774e9b62687d07abeb9a8829aace58a2da8a0c0 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Wed, 30 Mar 2022 14:48:09 +0200
Subject: [PATCH 03/17] ENH: add more log info

---
 python/otbtf.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 8f8f9359..17fd06c0 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -407,8 +407,14 @@ class Dataset:
         :param max_nb_of_samples: Optional, max number of samples to consider
         """
         self.patches_reader = patches_reader
-        self.size = min(self.patches_reader.get_size(),
-                        max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size()
+
+        # If necessary, limit the nb of samples
+        logging.info('There are %s samples available', self.patches_reader.get_size())
+        if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples:
+            logging.info('Reducing number of samples to %s', max_nb_of_samples)
+            self.size = max_nb_of_samples
+        else:
+            self.size = self.patches_reader.get_size()
 
         # iterator
         self.iterator = Iterator(patches_reader=self.patches_reader)
-- 
GitLab


From 120a7ba69b75400dc56c8aa70ff5902cb8a91cc0 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Mon, 4 Apr 2022 16:36:03 +0200
Subject: [PATCH 04/17] FIX: remove API breaker

---
 python/otbtf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 17fd06c0..d13d5788 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -173,19 +173,19 @@ class PatchesImagesReader(PatchesReaderBase):
     :see PatchesReaderBase
     """
 
-    def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False):
+    def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
              src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif],
              ...
              src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]}
+        :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory.
         :param scalar_dict: (optional) a dict containing lists of scalars (int, float, str) as follows:
             {scalar_name1: ["value_1", ..., "value_N"],
              scalar_name2: [value_1, ..., value_N],
              ...
-             scalar_nameN: [value1, ..., value_N]}
-        :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory.
+             scalar_nameM: [value_1, ..., value_N]}
         """
 
         assert len(filenames_dict.values()) > 0
-- 
GitLab


From 817a69c7806e7dda1976916339743c0067d78304 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Mon, 4 Apr 2022 16:55:39 +0200
Subject: [PATCH 05/17] FIX: add tqdm dependency

---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index 34b9b4a4..ece649b3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -25,7 +25,7 @@ RUN if $GUI; then \
 RUN ln -s /usr/bin/python3 /usr/local/bin/python && ln -s /usr/bin/pip3 /usr/local/bin/pip
 # NumPy version is conflicting with system's gdal dep and may require venv
 ARG NUMPY_SPEC="==1.19.*"
-RUN pip install --no-cache-dir -U pip wheel mock six future deprecated "numpy$NUMPY_SPEC" \
+RUN pip install --no-cache-dir -U pip wheel mock six future tqdm deprecated "numpy$NUMPY_SPEC" \
  && pip install --no-cache-dir --no-deps keras_applications keras_preprocessing
 
 # ----------------------------------------------------------------------------
-- 
GitLab


From 0da55b2105cb256d2f6212ebe06033002a8c0164 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Mon, 4 Apr 2022 17:02:24 +0200
Subject: [PATCH 06/17] REFAC: replace system module with plain Python os library

---
 python/otbtf.py  | 15 +++++------
 python/system.py | 68 ------------------------------------------------
 2 files changed, 7 insertions(+), 76 deletions(-)
 delete mode 100644 python/system.py

diff --git a/python/otbtf.py b/python/otbtf.py
index d13d5788..c702375e 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -34,7 +34,6 @@ from tqdm import tqdm
 import numpy as np
 import tensorflow as tf
 import gdal
-import system
 
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
@@ -581,15 +580,15 @@ class TFRecords:
         """
         :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path
         """
-        if system.is_dir(path) or not os.path.exists(path):
+        if os.path.isdir(path) or not os.path.exists(path):
             self.dirpath = path
-            system.mkdir(self.dirpath)
-            self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath))
+            os.makedirs(self.dirpath, exist_ok=True)
+            self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records")
         else:
-            self.dirpath = system.dirname(path)
+            self.dirpath = os.path.dirname(path)
             self.tfrecords_pattern_path = path
-        self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath))
-        self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath))
+        self.output_types_file = os.path.join(self.dirpath, "output_types.json")
+        self.output_shape_file = os.path.join(self.dirpath, "output_shape.json")
         self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
         self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
 
@@ -639,7 +638,7 @@ class TFRecords:
             else:
                 nb_sample = dataset.size - i * n_samples_per_shard
 
-            filepath = "{}{}.records".format(system.pathify(self.dirpath), i)
+            filepath = os.path.join(self.dirpath, f"{i}.records")
             with tf.io.TFRecordWriter(filepath) as writer:
                 for s in range(nb_sample):
                     sample = dataset.read_one_sample()
diff --git a/python/system.py b/python/system.py
deleted file mode 100644
index e7b581fd..00000000
--- a/python/system.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Copyright (c) 2020-2022 INRAE
-
-Permission is hereby granted, free of charge, to any person obtaining a
-copy of this software and associated documentation files (the "Software"),
-to deal in the Software without restriction, including without limitation
-the rights to use, copy, modify, merge, publish, distribute, sublicense,
-and/or sell copies of the Software, and to permit persons to whom the
-Software is furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-DEALINGS IN THE SOFTWARE.
-"""
-"""Various system operations"""
-import logging
-import pathlib
-import os
-
-# ---------------------------------------------------- Helpers ---------------------------------------------------------
-
-def pathify(pth):
-    """ Adds posix separator if needed """
-    if not pth.endswith("/"):
-        pth += "/"
-    return pth
-
-
-def mkdir(pth):
-    """ Create a directory """
-    path = pathlib.Path(pth)
-    path.mkdir(parents=True, exist_ok=True)
-
-
-def dirname(filename):
-    """ Returns the parent directory of the file """
-    return str(pathlib.Path(filename).parent)
-
-
-def basic_logging_init():
-    """ basic logging initialization """
-    logging.basicConfig(
-        format='%(asctime)s %(levelname)-8s %(message)s',
-        level=logging.INFO,
-        datefmt='%Y-%m-%d %H:%M:%S')
-
-
-def logging_info(msg, verbose=True):
-    """
-    Prints log info only if required by `verbose`
-    :param msg: message to log
-    :param verbose: boolean. Whether to log msg or not. Default True
-    :return:
-    """
-    if verbose:
-        logging.info(msg)
-
-def is_dir(filename):
-    """ return True if filename is the path to a directory """
-    return os.path.isdir(filename)
-- 
GitLab


From e82173f7268ba63dea90510b2503fd72fb7fa1e1 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Tue, 12 Apr 2022 12:24:38 +0200
Subject: [PATCH 07/17] FIX: make `_read_extract_as_np_arr` method return 3D
 arrays even for single-band rasters

---
 python/otbtf.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/otbtf.py b/python/otbtf.py
index c702375e..71e2f6a3 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -258,6 +258,9 @@ class PatchesImagesReader(PatchesReaderBase):
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
         if len(buffer.shape) == 3:
             buffer = np.transpose(buffer, axes=(1, 2, 0))
+        else:  # single-band raster
+            buffer = np.expand_dims(buffer, axis=2)
+
         return np.float32(buffer)
 
     def get_sample(self, index):
-- 
GitLab
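
The fix matters because GDAL's ReadAsArray() returns a 2D array for a single-band raster and a 3D (bands, rows, cols) array otherwise; a small numpy illustration of both code paths (shapes are arbitrary):

    import numpy as np

    single_band = np.zeros((16, 16), dtype=np.float32)     # ReadAsArray() on a 1-band patch
    multi_band = np.zeros((4, 16, 16), dtype=np.float32)   # ReadAsArray() on a 4-band patch

    single_band = np.expand_dims(single_band, axis=2)       # -> (16, 16, 1)
    multi_band = np.transpose(multi_band, axes=(1, 2, 0))   # -> (16, 16, 4)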


From 4ab9bdb59daae0f6adfbec9990bd4ff9b970398f Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Tue, 19 Apr 2022 16:41:39 +0200
Subject: [PATCH 08/17] ENH: add the possibility to specify cropping of the
 target when reading TFRecords

---
 python/otbtf.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 61a0767f..598eab65 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -685,12 +685,13 @@ class TFRecords:
         self.save(output_shapes, self.output_shape_file)
 
     @staticmethod
-    def parse_tfrecord(example, features_types, target_keys):
+    def parse_tfrecord(example, features_types, target_keys, target_cropping=None):
         """
         Parse example object to sample dict.
         :param example: Example object to parse
         :param features_types: List of types for each feature
         :param target_keys: list of keys of the targets
+        :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor.
         """
         read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
         example_parsed = tf.io.parse_single_example(example, read_features)
@@ -702,27 +703,33 @@ class TFRecords:
         input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
         target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
 
+        if target_cropping:
+            print({key: value for key, value in target_parsed.items()})
+            target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()}
+
         return input_parsed, target_parsed
 
 
-    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+    def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
         """
         Read all tfrecord files matching with pattern and convert data to tensorflow dataset.
         :param batch_size: Size of tensorflow batch
-        :param target_key: Key of the target, e.g. 's2_out'
+        :param target_keys: Keys of the target, e.g. ['s2_out']
+        :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network
+                               architecture coherent with this, i.e. that has a Cropping2D layer in the end
         :param n_workers: number of workers, e.g. 4 if using 4 GPUs
                                              e.g. 12 if using 3 nodes of 4 GPUs
         :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
                                `batch_size` elements. True is advisable when training on multiworkers.
                                False is advisable when evaluating metrics so that all samples are used
-        :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size
+        :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size
                                     elements are shuffled using uniform random.
         """
         options = tf.data.Options()
         if shuffle_buffer_size:
             options.experimental_deterministic = False  # disable order, increase speed
         options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping)
 
         # TODO: to be investigated :
         # 1/ num_parallel_reads useful ? I/O bottleneck of not ?
-- 
GitLab


From 31a4f931a29899431e62cbaabd6d86747c3ecd37 Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Tue, 19 Apr 2022 16:53:59 +0200
Subject: [PATCH 09/17] STYLE: remove debug prints

---
 python/otbtf.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 598eab65..c944728c 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -425,7 +425,6 @@ class Dataset:
         self.output_types = dict()
         self.output_shapes = dict()
         one_sample = self.patches_reader.get_sample(index=0)
-        print(one_sample)  # DEBUG
         for src_key, np_arr in one_sample.items():
             self.output_shapes[src_key] = np_arr.shape
             self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype)
@@ -704,7 +703,6 @@ class TFRecords:
         target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
 
         if target_cropping:
-            print({key: value for key, value in target_parsed.items()})
             target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()}
 
         return input_parsed, target_parsed
-- 
GitLab


From c3eb4703503a45e5cc8b9a4ea409c7341d46558d Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Wed, 20 Apr 2022 10:49:10 +0200
Subject: [PATCH 10/17] REFAC: simplify scalar handling and sample reading,
 remove Dataset.feed

---
 python/otbtf.py | 74 ++++++++++++++++++-------------------------------
 1 file changed, 27 insertions(+), 47 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index c944728c..922d13db 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase):
     :see PatchesReaderBase
     """
 
-    def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None):
+    def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
@@ -192,18 +192,18 @@ class PatchesImagesReader(PatchesReaderBase):
         # gdal_ds dict
         self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()}
 
-        # Scalar parameters (e.g. metadatas)
-        self.scalar_dict = scalar_dict
-        if scalar_dict is None:
-            self.scalar_dict = {}
+        # streaming on/off
+        self.use_streaming = use_streaming
+
+        # Scalar dict (e.g. for metadata)
+        # If the scalars are not numpy.ndarray, convert them
+        self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars]
+                            for key, scalars in scalar_dict.items()}
 
         # check number of patches in each sources
         if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1:
             raise Exception("Each source must have the same number of patches images")
 
-        # streaming on/off
-        self.use_streaming = use_streaming
-
         # gdal_ds check
         nb_of_patches = {key: 0 for key in self.gdal_ds}
         self.nb_of_channels = dict()
@@ -226,14 +226,8 @@ class PatchesImagesReader(PatchesReaderBase):
 
         # if use_streaming is False, we store in memory all patches images
         if not self.use_streaming:
-            patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds}
-            self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds}
-            # Create a scalars dict so that one scalar <-> one patch
-            self.scalar_buffer = {}
-            for src_key, scalars in self.scalar_dict.items():
-                self.scalar_buffer[src_key] = []
-                for scalar, ds_size in zip(scalars, self.ds_sizes):
-                    self.scalar_buffer[src_key].extend([scalar] * ds_size)
+            self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for
+                                   src_key, src_ds in self.gdal_ds.items()}
 
     def _get_ds_and_offset_from_index(self, index):
         offset = index
@@ -276,20 +270,19 @@ class PatchesImagesReader(PatchesReaderBase):
         assert index >= 0
         assert index < self.size
 
+        i, offset = self._get_ds_and_offset_from_index(index)
+        res = {src_key: scalar[i] for src_key, scalar in self.scalar_dict.items()}
         if not self.use_streaming:
-            res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds}
-            res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()})
+            res.update({src_key: arr[index, :, :, :] for src_key, arr in self.patches_buffer.items()})
         else:
-            i, offset = self._get_ds_and_offset_from_index(index)
-            res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds}
-            res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()})
-
+            res.update({src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset)
+                        for src_key in self.gdal_ds})
         return res
 
     def get_stats(self):
         """
         Compute some statistics for each source.
-        Depending if streaming is used, the statistics are computed directly in memory, or chunk-by-chunk.
+        When streaming is used, the statistics are computed chunk-by-chunk. Otherwise, they are computed directly in memory.
 
         :return statistics dict
         """
@@ -340,6 +333,7 @@ class IteratorBase(ABC):
     """
     Base class for iterators
     """
+
     @abstractmethod
     def __init__(self, patches_reader: PatchesReaderBase):
         pass
@@ -396,22 +390,10 @@ class Dataset:
         :param max_nb_of_samples: Optional, max number of samples to consider
         """
         # patches reader
-        if patches_reader:
-            self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples)
-
-
-    def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128,
-                 Iterator=RandomIterator, max_nb_of_samples=None):
-        """
-        :param patches_reader: The patches reader instance
-        :param buffer_length: The number of samples that are stored in the buffer
-        :param Iterator: The iterator class used to generate the sequence of patches indices.
-        :param max_nb_of_samples: Optional, max number of samples to consider
-        """
         self.patches_reader = patches_reader
 
         # If necessary, limit the nb of samples
-        logging.info('There are %s samples available', self.patches_reader.get_size())
+        logging.info('Number of samples: %s', self.patches_reader.get_size())
         if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples:
             logging.info('Reducing number of samples to %s', max_nb_of_samples)
             self.size = max_nb_of_samples
@@ -451,14 +433,19 @@ class Dataset:
 
     def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True):
         """
+        Save the dataset into TFRecord files
 
+        :param output_dir: output directory
+        :param n_samples_per_shard: number of samples per TFRecord file
+        :param drop_remainder: drop remainder samples
         """
         tfrecord = TFRecords(output_dir)
         tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder)
 
-
     def get_stats(self) -> dict:
         """
+        Compute dataset statistics
+
         :return: the dataset statistics, computed by the patches reader
         """
         with self.mining_lock:
@@ -684,13 +671,12 @@ class TFRecords:
         self.save(output_shapes, self.output_shape_file)
 
     @staticmethod
-    def parse_tfrecord(example, features_types, target_keys, target_cropping=None):
+    def parse_tfrecord(example, features_types, target_keys):
         """
         Parse example object to sample dict.
         :param example: Example object to parse
         :param features_types: List of types for each feature
         :param target_keys: list of keys of the targets
-        :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor.
         """
         read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
         example_parsed = tf.io.parse_single_example(example, read_features)
@@ -702,19 +688,13 @@ class TFRecords:
         input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
         target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
 
-        if target_cropping:
-            target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()}
-
         return input_parsed, target_parsed
 
-
-    def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
         """
         Read all tfrecord files matching with pattern and convert data to tensorflow dataset.
         :param batch_size: Size of tensorflow batch
         :param target_keys: Keys of the target, e.g. ['s2_out']
-        :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network
-                               architecture coherent with this, i.e. that has a Cropping2D layer in the end
         :param n_workers: number of workers, e.g. 4 if using 4 GPUs
                                              e.g. 12 if using 3 nodes of 4 GPUs
         :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
@@ -727,7 +707,7 @@ class TFRecords:
         if shuffle_buffer_size:
             options.experimental_deterministic = False  # disable order, increase speed
         options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping)
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
 
         # TODO: to be investigated :
         # 1/ num_parallel_reads useful ? I/O bottleneck of not ?
-- 
GitLab
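
After this refactoring, get_sample() returns the patches and the scalars in a single dict; a short sketch with hypothetical file names (streaming mode, so patches are read on the fly):

    from otbtf import PatchesImagesReader

    reader = PatchesImagesReader(filenames_dict={"s2": ["s2_patches_1.tif"]},
                                 use_streaming=True,
                                 scalar_dict={"month": [4]})
    sample = reader.get_sample(index=0)
    # sample["s2"]    -> np.ndarray of shape (psz, psz, nb_channels)
    # sample["month"] -> np.ndarray of shape (), i.e. the scalar 4 wrapped as a numpy array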


From 5186eb5e0d6a771f437e7666b576fee7769c521c Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Wed, 20 Apr 2022 11:38:47 +0200
Subject: [PATCH 11/17] ENH: use None as default arg instead of {}

---
 python/otbtf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 922d13db..ad77b7a2 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase):
     :see PatchesReaderBase
     """
 
-    def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}):
+    def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
@@ -198,7 +198,7 @@ class PatchesImagesReader(PatchesReaderBase):
         # Scalar dict (e.g. for metadata)
         # If the scalars are not numpy.ndarray, convert them
         self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars]
-                            for key, scalars in scalar_dict.items()}
+                            for key, scalars in scalar_dict.items()} if scalar_dict else {}
 
         # check number of patches in each sources
         if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1:
-- 
GitLab
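
This change avoids the usual Python pitfall with mutable default arguments, where the same object is shared across all calls; a quick illustration:

    def append_to(item, target=[]):         # the default list is created once, at definition time
        target.append(item)
        return target

    append_to(1)   # [1]
    append_to(2)   # [1, 2] -- the same list object is reused

    def append_to_safe(item, target=None):  # the None sentinel gives each call a fresh list
        target = [] if target is None else target
        target.append(item)
        return target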


From 878dde7dd0ab49a28d18acf9ad5907e1a4e70f70 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Wed, 20 Apr 2022 11:44:40 +0200
Subject: [PATCH 12/17] REFAC: change import order

---
 python/otbtf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index ad77b7a2..cbb96e55 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -29,11 +29,10 @@ import time
 import logging
 from abc import ABC, abstractmethod
 from functools import partial
-from tqdm import tqdm
-
 import numpy as np
 import tensorflow as tf
 from osgeo import gdal
+from tqdm import tqdm
 
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
@@ -581,9 +580,10 @@ class TFRecords:
         self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
         self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
 
+    @staticmethod
     def _bytes_feature(self, value):
         """
-        Used to convert a value to a type compatible with tf.train.Example.
+        Convert a value to a type compatible with tf.train.Example.
         :param value: value
         :return a bytes_list from a string / byte.
         """
-- 
GitLab


From d223155ef2b4b616fe9d4eaddc2d1d760f61fdba Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Thu, 21 Apr 2022 14:25:40 +0200
Subject: [PATCH 13/17] FIX: list indices must be integers

---
 python/otbtf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index cbb96e55..83cb4fe5 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -225,7 +225,7 @@ class PatchesImagesReader(PatchesReaderBase):
 
         # if use_streaming is False, we store in memory all patches images
         if not self.use_streaming:
-            self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for
+            self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds], axis=0) for
                                    src_key, src_ds in self.gdal_ds.items()}
 
     def _get_ds_and_offset_from_index(self, index):
-- 
GitLab


From 1a2a42d0eb5ff61ff7266d702d21ef9578b27d4e Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Thu, 21 Apr 2022 14:27:34 +0200
Subject: [PATCH 14/17] FIX: static function

---
 python/otbtf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 83cb4fe5..75d188ae 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -581,7 +581,7 @@ class TFRecords:
         self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
 
     @staticmethod
-    def _bytes_feature(self, value):
+    def _bytes_feature(value):
         """
         Convert a value to a type compatible with tf.train.Example.
         :param value: value
-- 
GitLab


From c3717f3b86991d160af74ee09b7bf55a05371efb Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Wed, 20 Apr 2022 12:08:33 +0200
Subject: [PATCH 15/17] (Cherry-pick from 14) generalize target cropping to a
 preprocessing function

---
 python/otbtf.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 75d188ae..c180a7c4 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -672,11 +672,15 @@ class TFRecords:
 
     @staticmethod
     def parse_tfrecord(example, features_types, target_keys):
+    def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
         """
         Parse example object to sample dict.
         :param example: Example object to parse
         :param features_types: List of types for each feature
         :param target_keys: list of keys of the targets
+        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
+                                           a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: some keyword arguments for preprocessing_fn
         """
         read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
         example_parsed = tf.io.parse_single_example(example, read_features)
@@ -688,9 +692,13 @@ class TFRecords:
         input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
         target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
 
+        if preprocessing_fn:
+            input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs)
+
         return input_parsed, target_parsed
 
-    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None,
+             preprocessing_fn=None, **kwargs):
         """
         Read all tfrecord files matching with pattern and convert data to tensorflow dataset.
         :param batch_size: Size of tensorflow batch
@@ -702,12 +710,16 @@ class TFRecords:
                                False is advisable when evaluating metrics so that all samples are used
         :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size
                                     elements are shuffled using uniform random.
+        :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns
+                                   a tuple (input_preprocessed, target_preprocessed)
+        :param kwargs: some keyword arguments for preprocessing_fn
         """
         options = tf.data.Options()
         if shuffle_buffer_size:
             options.experimental_deterministic = False  # disable order, increase speed
         options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
-        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys,
+                        preprocessing_fn=preprocessing_fn, **kwargs)
 
         # TODO: to be investigated :
         # 1/ num_parallel_reads useful ? I/O bottleneck of not ?
-- 
GitLab


From 161e33e8cdc9f7024cb5c081f0d0561fdc300b30 Mon Sep 17 00:00:00 2001
From: Vincent Delbar <vincent.delbar@latelescop.fr>
Date: Thu, 21 Apr 2022 15:55:49 +0200
Subject: [PATCH 16/17] FIX: indented block

---
 python/otbtf.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index c180a7c4..b28a1cc4 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -671,7 +671,6 @@ class TFRecords:
         self.save(output_shapes, self.output_shape_file)
 
     @staticmethod
-    def parse_tfrecord(example, features_types, target_keys):
     def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs):
         """
         Parse example object to sample dict.
-- 
GitLab
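
With the duplicated definition removed, the preprocessing hook introduced in the previous patch can reproduce the former target cropping; a sketch in which crop_target, the crop value and the paths are hypothetical:

    from otbtf import TFRecords

    def crop_target(inputs, targets, crop=16):
        # Remove `crop` pixels on each side of every target tensor
        targets = {key: value[crop:-crop, crop:-crop, :] for key, value in targets.items()}
        return inputs, targets

    tf_dataset = TFRecords("/tmp/tfrecords").read(batch_size=8,
                                                  target_keys=["labels"],
                                                  preprocessing_fn=crop_target,
                                                  crop=16)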


From 929dae88414026ac9dc760cb01d23e7d99a4a7ce Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Thu, 21 Apr 2022 16:55:58 +0200
Subject: [PATCH 17/17] ENH: generate samples of the same type as the initial raster

---
 python/otbtf.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index b28a1cc4..a1cf9bd4 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -58,8 +58,11 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    buffer = gdal_ds.ReadAsArray()
+    gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
+                        10: 'complex64', 11: 'complex128'}
+    gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
+    buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -68,7 +71,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+    return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
 
 
 # -------------------------------------------------- Buffer class ------------------------------------------------------
@@ -244,8 +247,11 @@ class PatchesImagesReader(PatchesReaderBase):
 
     @staticmethod
     def _read_extract_as_np_arr(gdal_ds, offset):
+        gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
+                            10: 'complex64', 11: 'complex128'}
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
+        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
@@ -254,7 +260,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:  # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
 
-        return np.float32(buffer)
+        return buffer.astype(gdal_to_np_types[gdal_type])
 
     def get_sample(self, index):
         """
-- 
GitLab
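
The hard-coded GDAL-to-numpy type dict above can also be derived from GDAL itself; a sketch, assuming the installed GDAL bindings expose gdal_array and that the patches image path exists:

    from osgeo import gdal, gdal_array

    gdal_ds = gdal.Open("s2_patches_1.tif")
    gdal_type = gdal_ds.GetRasterBand(1).DataType
    np_dtype = gdal_array.GDALTypeCodeToNumericTypeCode(gdal_type)  # e.g. numpy.uint16
    buffer = gdal_ds.ReadAsArray().astype(np_dtype)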