From 86b38ce38b4790759922cad02bd4d5d56e86d2b6 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 23 Mar 2022 15:08:11 +0100 Subject: [PATCH 01/17] ENH: add TFRecord class + to_tfrecords method + split "array" data from "scalar" data + add max_nb_of_samples for Dataset --- python/otbtf.py | 241 +++++++++++++++++++++++++++++++++++++++++++++-- python/system.py | 68 +++++++++++++ 2 files changed, 301 insertions(+), 8 deletions(-) create mode 100644 python/system.py diff --git a/python/otbtf.py b/python/otbtf.py index a23d5237..45e13791 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -18,17 +18,23 @@ # # ==========================================================================*/ """ -Contains stuff to help working with TensorFlow and geospatial data in the -OTBTF framework. +Contains stuff to help working with TensorFlow and geospatial data in the OTBTF framework. """ +import glob +import json +import os import threading import multiprocessing import time import logging from abc import ABC, abstractmethod +from functools import partial +from tqdm import tqdm + import numpy as np import tensorflow as tf import gdal +import system # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -167,13 +173,18 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict: dict, use_streaming=False): + def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif], ... src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]} + :param scalar_dict: (optional) a dict containing list of scalars (int, float, str) as follow: + {scalar_name1: ["value_1", ..., "value_N"], + scalar_name2: [value_1, ..., value_N], + ... + scalar_nameN: [value1, ..., value_N]} :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. """ @@ -182,8 +193,13 @@ class PatchesImagesReader(PatchesReaderBase): # gdal_ds dict self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} + # Scalar parameters (e.g. 
metadatas) + self.scalar_dict = scalar_dict + if scalar_dict is None: + self.scalar_dict = {} + # check number of patches in each sources - if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1: + if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: raise Exception("Each source must have the same number of patches images") # streaming on/off @@ -213,6 +229,12 @@ class PatchesImagesReader(PatchesReaderBase): if not self.use_streaming: patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} + # Create a scalars dict so that one scalar <-> one patch + self.scalar_buffer = {} + for src_key, scalars in self.scalar_dict.items(): + self.scalar_buffer[src_key] = [] + for scalar, ds_size in zip(scalars, self.ds_sizes): + self.scalar_buffer[src_key].extend([scalar] * ds_size) def _get_ds_and_offset_from_index(self, index): offset = index @@ -254,9 +276,11 @@ class PatchesImagesReader(PatchesReaderBase): if not self.use_streaming: res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} + res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()}) else: i, offset = self._get_ds_and_offset_from_index(index) res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} + res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()}) return res @@ -362,16 +386,18 @@ class Dataset: """ def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128, - Iterator: IteratorBase = RandomIterator): + Iterator=RandomIterator, max_nb_of_samples=None): """ :param patches_reader: The patches reader instance :param buffer_length: The number of samples that are stored in the buffer :param Iterator: The iterator class used to generate the sequence of patches indices. 
+ :param max_nb_of_samples: Optional, max number of samples to consider """ # patches reader self.patches_reader = patches_reader - self.size = self.patches_reader.get_size() + self.size = min(self.patches_reader.get_size(), + max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() # iterator self.iterator = Iterator(patches_reader=self.patches_reader) @@ -380,6 +406,7 @@ class Dataset: self.output_types = dict() self.output_shapes = dict() one_sample = self.patches_reader.get_sample(index=0) + print(one_sample) # DEBUG for src_key, np_arr in one_sample.items(): self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) @@ -404,6 +431,14 @@ class Dataset: output_types=self.output_types, output_shapes=self.output_shapes).repeat(1) + def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True): + """ + + """ + tfrecord = TFRecords(output_dir) + tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder) + + def get_stats(self) -> dict: """ :return: the dataset statistics, computed by the patches reader @@ -502,8 +537,8 @@ class DatasetFromPatchesImages(Dataset): :see Dataset """ - def __init__(self, filenames_dict: dict, use_streaming: bool = False, buffer_length: int = 128, - Iterator: IteratorBase = RandomIterator): + def __init__(self, filenames_dict, use_streaming=False, buffer_length: int = 128, + Iterator=RandomIterator): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image1, ..., src1_patches_imageN1], @@ -518,3 +553,193 @@ class DatasetFromPatchesImages(Dataset): patches_reader = PatchesImagesReader(filenames_dict=filenames_dict, use_streaming=use_streaming) super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator) + + +class TFRecords: + """ + This class allows to convert Dataset objects to TFRecords and to load them in dataset tensorflows format. + """ + + def __init__(self, path): + """ + :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path + """ + if system.is_dir(path) or not os.path.exists(path): + self.dirpath = path + system.mkdir(self.dirpath) + self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath)) + else: + self.dirpath = system.dirname(path) + self.tfrecords_pattern_path = path + self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath)) + self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath)) + self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None + self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None + + def _bytes_feature(self, value): + """ + Used to convert a value to a type compatible with tf.train.Example. + :param value: value + :return a bytes_list from a string / byte. + """ + if isinstance(value, type(tf.constant(0))): + value = value.numpy() # BytesList won't unpack a string from an EagerTensor. + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True): + """ + Convert and save samples from dataset object to tfrecord files. + :param dataset: Dataset object to convert into a set of tfrecords + :param n_samples_per_shard: Number of samples per shard + :param drop_remainder: Whether additional samples should be dropped. 
Advisable if using multiworkers training. + If True, all TFRecords will have `n_samples_per_shard` samples + """ + logging.info("%s samples", dataset.size) + + nb_shards = (dataset.size // n_samples_per_shard) + if not drop_remainder and dataset.size % n_samples_per_shard > 0: + nb_shards += 1 + + self.convert_dataset_output_shapes(dataset) + + def _convert_data(data): + """ + Convert data + """ + data_converted = {} + + for k, d in data.items(): + data_converted[k] = d.name + + return data_converted + + self.save(_convert_data(dataset.output_types), self.output_types_file) + + for i in tqdm(range(nb_shards)): + + if (i + 1) * n_samples_per_shard <= dataset.size: + nb_sample = n_samples_per_shard + else: + nb_sample = dataset.size - i * n_samples_per_shard + + filepath = "{}{}.records".format(system.pathify(self.dirpath), i) + with tf.io.TFRecordWriter(filepath) as writer: + for s in range(nb_sample): + sample = dataset.read_one_sample() + serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} + features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in + serialized_sample.items()} + tf_features = tf.train.Features(feature=features) + example = tf.train.Example(features=tf_features) + writer.write(example.SerializeToString()) + + @staticmethod + def save(data, filepath): + """ + Save data to pickle format. + :param data: Data to save json format + :param filepath: Output file name + """ + + with open(filepath, 'w') as f: + json.dump(data, f, indent=4) + + @staticmethod + def load(filepath): + """ + Return data from pickle format. + :param filepath: Input file name + """ + with open(filepath, 'r') as f: + return json.load(f) + + def convert_dataset_output_shapes(self, dataset): + """ + Convert and save numpy shape to tensorflow shape. + :param dataset: Dataset object containing output shapes + """ + output_shapes = {} + + for key in dataset.output_shapes.keys(): + output_shapes[key] = (None,) + dataset.output_shapes[key] + + self.save(output_shapes, self.output_shape_file) + + @staticmethod + def parse_tfrecord(example, features_types, target_keys): + """ + Parse example object to sample dict. + :param example: Example object to parse + :param features_types: List of types for each feature + :param target_keys: list of keys of the targets + """ + read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} + example_parsed = tf.io.parse_single_example(example, read_features) + + for key in read_features.keys(): + example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key]) + + # Differentiating inputs and outputs + input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} + target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + + return input_parsed, target_parsed + + + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + """ + Read all tfrecord files matching with pattern and convert data to tensorflow dataset. + :param batch_size: Size of tensorflow batch + :param target_key: Key of the target, e.g. 's2_out' + :param n_workers: number of workers, e.g. 4 if using 4 GPUs + e.g. 12 if using 3 nodes of 4 GPUs + :param drop_remainder: whether the last batch should be dropped in the case it has fewer than + `batch_size` elements. True is advisable when training on multiworkers. 
+ False is advisable when evaluating metrics so that all samples are used + :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size + elements are shuffled using uniform random. + """ + options = tf.data.Options() + if shuffle_buffer_size: + options.experimental_deterministic = False # disable order, increase speed + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + + # TODO: to be investigated : + # 1/ num_parallel_reads useful ? I/O bottleneck of not ? + # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful ? + # 3/ shuffle or not shuffle ? + matching_files = glob.glob(self.tfrecords_pattern_path) + logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path) + logging.info('Number of matching TFRecords: %s', len(matching_files)) + matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)] # files multiple of workers + nb_matching_files = len(matching_files) + if nb_matching_files == 0: + raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord " + "files is greater or equal than the number of workers!".format(self.tfrecords_pattern_path)) + logging.info('Reducing number of records to : %s', nb_matching_files) + dataset = tf.data.TFRecordDataset(matching_files) # , num_parallel_reads=2) # interleaves reads from xxx files + dataset = dataset.with_options(options) # uses data as soon as it streams in, rather than in its original order + dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if shuffle_buffer_size: + dataset = dataset.shuffle(buffer_size=shuffle_buffer_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + # TODO voir si on met le prefetch avant le batch cf https://keras.io/examples/keras_recipes/tfrecord/ + + return dataset + + def read_one_sample(self, target_keys): + """ + Read one tfrecord file matching with pattern and convert data to tensorflow dataset. + :param target_key: Key of the target, e.g. 's2_out' + """ + matching_files = glob.glob(self.tfrecords_pattern_path) + one_file = matching_files[0] + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + dataset = tf.data.TFRecordDataset(one_file) + dataset = dataset.map(parse) + dataset = dataset.batch(1) + + sample = iter(dataset).get_next() + return sample diff --git a/python/system.py b/python/system.py new file mode 100644 index 00000000..e7b581fd --- /dev/null +++ b/python/system.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +""" +Copyright (c) 2020-2022 INRAE + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +"""Various system operations""" +import logging +import pathlib +import os + +# ---------------------------------------------------- Helpers --------------------------------------------------------- + +def pathify(pth): + """ Adds posix separator if needed """ + if not pth.endswith("/"): + pth += "/" + return pth + + +def mkdir(pth): + """ Create a directory """ + path = pathlib.Path(pth) + path.mkdir(parents=True, exist_ok=True) + + +def dirname(filename): + """ Returns the parent directory of the file """ + return str(pathlib.Path(filename).parent) + + +def basic_logging_init(): + """ basic logging initialization """ + logging.basicConfig( + format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, + datefmt='%Y-%m-%d %H:%M:%S') + + +def logging_info(msg, verbose=True): + """ + Prints log info only if required by `verbose` + :param msg: message to log + :param verbose: boolean. Whether to log msg or not. Default True + :return: + """ + if verbose: + logging.info(msg) + +def is_dir(filename): + """ return True if filename is the path to a directory """ + return os.path.isdir(filename) -- GitLab From 28e489468ece5134fd441df1609a8ded77a95531 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 24 Mar 2022 16:43:46 +0100 Subject: [PATCH 02/17] ENH: split the Dataset initialisation and the patch_reader feeding --- python/otbtf.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 45e13791..8f8f9359 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -385,7 +385,7 @@ class Dataset: :see Buffer """ - def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128, + def __init__(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, Iterator=RandomIterator, max_nb_of_samples=None): """ :param patches_reader: The patches reader instance @@ -393,8 +393,19 @@ class Dataset: :param Iterator: The iterator class used to generate the sequence of patches indices. :param max_nb_of_samples: Optional, max number of samples to consider """ - # patches reader + if patches_reader: + self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples) + + + def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, + Iterator=RandomIterator, max_nb_of_samples=None): + """ + :param patches_reader: The patches reader instance + :param buffer_length: The number of samples that are stored in the buffer + :param Iterator: The iterator class used to generate the sequence of patches indices. 
+ :param max_nb_of_samples: Optional, max number of samples to consider + """ self.patches_reader = patches_reader self.size = min(self.patches_reader.get_size(), max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() -- GitLab From c774e9b62687d07abeb9a8829aace58a2da8a0c0 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 30 Mar 2022 14:48:09 +0200 Subject: [PATCH 03/17] ENH: add more log info --- python/otbtf.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 8f8f9359..17fd06c0 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -407,8 +407,14 @@ class Dataset: :param max_nb_of_samples: Optional, max number of samples to consider """ self.patches_reader = patches_reader - self.size = min(self.patches_reader.get_size(), - max_nb_of_samples) if max_nb_of_samples else self.patches_reader.get_size() + + # If necessary, limit the nb of samples + logging.info('There are %s samples available', self.patches_reader.get_size()) + if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples: + logging.info('Reducing number of samples to %s', max_nb_of_samples) + self.size = max_nb_of_samples + else: + self.size = self.patches_reader.get_size() # iterator self.iterator = Iterator(patches_reader=self.patches_reader) -- GitLab From 120a7ba69b75400dc56c8aa70ff5902cb8a91cc0 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 16:36:03 +0200 Subject: [PATCH 04/17] FIX: remove API breaker --- python/otbtf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 17fd06c0..d13d5788 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -173,19 +173,19 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif], ... src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]} + :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. :param scalar_dict: (optional) a dict containing list of scalars (int, float, str) as follow: {scalar_name1: ["value_1", ..., "value_N"], scalar_name2: [value_1, ..., value_N], ... - scalar_nameN: [value1, ..., value_N]} - :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory. 
+ scalar_nameM: [value1, ..., valueN]} """ assert len(filenames_dict.values()) > 0 -- GitLab From 817a69c7806e7dda1976916339743c0067d78304 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 16:55:39 +0200 Subject: [PATCH 05/17] FIX: add tqdm dependency --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 34b9b4a4..ece649b3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ RUN if $GUI; then \ RUN ln -s /usr/bin/python3 /usr/local/bin/python && ln -s /usr/bin/pip3 /usr/local/bin/pip # NumPy version is conflicting with system's gdal dep and may require venv ARG NUMPY_SPEC="==1.19.*" -RUN pip install --no-cache-dir -U pip wheel mock six future deprecated "numpy$NUMPY_SPEC" \ +RUN pip install --no-cache-dir -U pip wheel mock six future tqdm deprecated "numpy$NUMPY_SPEC" \ && pip install --no-cache-dir --no-deps keras_applications keras_preprocessing # ---------------------------------------------------------------------------- -- GitLab From 0da55b2105cb256d2f6212ebe06033002a8c0164 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Mon, 4 Apr 2022 17:02:24 +0200 Subject: [PATCH 06/17] REFAC: replace system by plain python os library --- python/otbtf.py | 15 +++++------ python/system.py | 68 ------------------------------------------------ 2 files changed, 7 insertions(+), 76 deletions(-) delete mode 100644 python/system.py diff --git a/python/otbtf.py b/python/otbtf.py index d13d5788..c702375e 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -34,7 +34,6 @@ from tqdm import tqdm import numpy as np import tensorflow as tf import gdal -import system # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -581,15 +580,15 @@ class TFRecords: """ :param path: Can be a directory where TFRecords must be saved/loaded or a single TFRecord path """ - if system.is_dir(path) or not os.path.exists(path): + if os.path.isdir(path) or not os.path.exists(path): self.dirpath = path - system.mkdir(self.dirpath) - self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath)) + os.makedirs(self.dirpath, exist_ok=True) + self.tfrecords_pattern_path = os.path.join(self.dirpath, "*.records") else: - self.dirpath = system.dirname(path) + self.dirpath = os.path.dirname(path) self.tfrecords_pattern_path = path - self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath)) - self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath)) + self.output_types_file = os.path.join(self.dirpath, "output_types.json") + self.output_shape_file = os.path.join(self.dirpath, "output_shape.json") self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None @@ -639,7 +638,7 @@ class TFRecords: else: nb_sample = dataset.size - i * n_samples_per_shard - filepath = "{}{}.records".format(system.pathify(self.dirpath), i) + filepath = os.path.join(self.dirpath, f"{i}.records") with tf.io.TFRecordWriter(filepath) as writer: for s in range(nb_sample): sample = dataset.read_one_sample() diff --git a/python/system.py b/python/system.py deleted file mode 100644 index e7b581fd..00000000 --- a/python/system.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Copyright (c) 2020-2022 INRAE - -Permission is hereby granted, 
free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" -"""Various system operations""" -import logging -import pathlib -import os - -# ---------------------------------------------------- Helpers --------------------------------------------------------- - -def pathify(pth): - """ Adds posix separator if needed """ - if not pth.endswith("/"): - pth += "/" - return pth - - -def mkdir(pth): - """ Create a directory """ - path = pathlib.Path(pth) - path.mkdir(parents=True, exist_ok=True) - - -def dirname(filename): - """ Returns the parent directory of the file """ - return str(pathlib.Path(filename).parent) - - -def basic_logging_init(): - """ basic logging initialization """ - logging.basicConfig( - format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') - - -def logging_info(msg, verbose=True): - """ - Prints log info only if required by `verbose` - :param msg: message to log - :param verbose: boolean. Whether to log msg or not. 
Default True - :return: - """ - if verbose: - logging.info(msg) - -def is_dir(filename): - """ return True if filename is the path to a directory """ - return os.path.isdir(filename) -- GitLab From e82173f7268ba63dea90510b2503fd72fb7fa1e1 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 12 Apr 2022 12:24:38 +0200 Subject: [PATCH 07/17] FIX: make `_read_extract_as_np_arr` method return 3D arrays even for singleband --- python/otbtf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/otbtf.py b/python/otbtf.py index c702375e..71e2f6a3 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -258,6 +258,9 @@ class PatchesImagesReader(PatchesReaderBase): buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) + else: # single-band raster + buffer = np.expand_dims(buffer, axis=2) + return np.float32(buffer) def get_sample(self, index): -- GitLab From 4ab9bdb59daae0f6adfbec9990bd4ff9b970398f Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 19 Apr 2022 16:41:39 +0200 Subject: [PATCH 08/17] ENH: add the possibility to specify cropping of the target when reading TFRecords --- python/otbtf.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 61a0767f..598eab65 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -685,12 +685,13 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys): + def parse_tfrecord(example, features_types, target_keys, target_cropping=None): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets + :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor. """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -702,27 +703,33 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + if target_cropping: + print({key: value for key, value in target_parsed.items()}) + target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} + return input_parsed, target_parsed - def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. :param batch_size: Size of tensorflow batch - :param target_key: Key of the target, e.g. 's2_out' + :param target_keys: Keys of the target, e.g. ['s2_out'] + :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network + architecture coherent with this, i.e. that has a Cropping2D layer in the end :param n_workers: number of workers, e.g. 4 if using 4 GPUs e.g. 12 if using 3 nodes of 4 GPUs :param drop_remainder: whether the last batch should be dropped in the case it has fewer than `batch_size` elements. True is advisable when training on multiworkers. 
False is advisable when evaluating metrics so that all samples are used - :param shuffle_buffer_size: is None, shuffle is not used. Else, blocks of shuffle_buffer_size + :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are shuffled using uniform random. """ options = tf.data.Options() if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping) # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? -- GitLab From 31a4f931a29899431e62cbaabd6d86747c3ecd37 Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Tue, 19 Apr 2022 16:53:59 +0200 Subject: [PATCH 09/17] STYLE: remove debug prints --- python/otbtf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 598eab65..c944728c 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -425,7 +425,6 @@ class Dataset: self.output_types = dict() self.output_shapes = dict() one_sample = self.patches_reader.get_sample(index=0) - print(one_sample) # DEBUG for src_key, np_arr in one_sample.items(): self.output_shapes[src_key] = np_arr.shape self.output_types[src_key] = tf.dtypes.as_dtype(np_arr.dtype) @@ -704,7 +703,6 @@ class TFRecords: target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} if target_cropping: - print({key: value for key, value in target_parsed.items()}) target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} return input_parsed, target_parsed -- GitLab From c3eb4703503a45e5cc8b9a4ea409c7341d46558d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 10:49:10 +0200 Subject: [PATCH 10/17] ADD: modifications --- python/otbtf.py | 74 ++++++++++++++++++------------------------------- 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index c944728c..922d13db 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], @@ -192,18 +192,18 @@ class PatchesImagesReader(PatchesReaderBase): # gdal_ds dict self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} - # Scalar parameters (e.g. metadatas) - self.scalar_dict = scalar_dict - if scalar_dict is None: - self.scalar_dict = {} + # streaming on/off + self.use_streaming = use_streaming + + # Scalar dict (e.g. 
for metadata) + # If the scalars are not numpy.ndarray, convert them + self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars] + for key, scalars in scalar_dict.items()} # check number of patches in each sources if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: raise Exception("Each source must have the same number of patches images") - # streaming on/off - self.use_streaming = use_streaming - # gdal_ds check nb_of_patches = {key: 0 for key in self.gdal_ds} self.nb_of_channels = dict() @@ -226,14 +226,8 @@ class PatchesImagesReader(PatchesReaderBase): # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} - self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} - # Create a scalars dict so that one scalar <-> one patch - self.scalar_buffer = {} - for src_key, scalars in self.scalar_dict.items(): - self.scalar_buffer[src_key] = [] - for scalar, ds_size in zip(scalars, self.ds_sizes): - self.scalar_buffer[src_key].extend([scalar] * ds_size) + self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for + src_key, src_ds in self.gdal_ds.items()} def _get_ds_and_offset_from_index(self, index): offset = index @@ -276,20 +270,19 @@ class PatchesImagesReader(PatchesReaderBase): assert index >= 0 assert index < self.size + i, offset = self._get_ds_and_offset_from_index(index) + res = {src_key: scalar[i] for src_key, scalar in self.scalar_dict.items()} if not self.use_streaming: - res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} - res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()}) + res.update({src_key: arr[index, :, :, :] for src_key, arr in self.patches_buffer.items()}) else: - i, offset = self._get_ds_and_offset_from_index(index) - res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} - res.update({key: np.asarray(scalars[i]) for key, scalars in self.scalar_dict.items()}) - + res.update({src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) + for src_key in self.gdal_ds}) return res def get_stats(self): """ Compute some statistics for each source. - Depending if streaming is used, the statistics are computed directly in memory, or chunk-by-chunk. + When streaming is used, chunk-by-chunk. Else, the statistics are computed directly in memory. :return statistics dict """ @@ -340,6 +333,7 @@ class IteratorBase(ABC): """ Base class for iterators """ + @abstractmethod def __init__(self, patches_reader: PatchesReaderBase): pass @@ -396,22 +390,10 @@ class Dataset: :param max_nb_of_samples: Optional, max number of samples to consider """ # patches reader - if patches_reader: - self.feed(patches_reader, buffer_length, Iterator, max_nb_of_samples) - - - def feed(self, patches_reader: PatchesReaderBase = None, buffer_length: int = 128, - Iterator=RandomIterator, max_nb_of_samples=None): - """ - :param patches_reader: The patches reader instance - :param buffer_length: The number of samples that are stored in the buffer - :param Iterator: The iterator class used to generate the sequence of patches indices. 
- :param max_nb_of_samples: Optional, max number of samples to consider - """ self.patches_reader = patches_reader # If necessary, limit the nb of samples - logging.info('There are %s samples available', self.patches_reader.get_size()) + logging.info('Number of samples: %s', self.patches_reader.get_size()) if max_nb_of_samples and self.patches_reader.get_size() > max_nb_of_samples: logging.info('Reducing number of samples to %s', max_nb_of_samples) self.size = max_nb_of_samples @@ -451,14 +433,19 @@ class Dataset: def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True): """ + Save the dataset into TFRecord files + :param output_dir: output directory + :param n_samples_per_shard: number of samples per TFRecord file + :param drop_remainder: drop remainder samples """ tfrecord = TFRecords(output_dir) tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder) - def get_stats(self) -> dict: """ + Compute dataset statistics + :return: the dataset statistics, computed by the patches reader """ with self.mining_lock: @@ -684,13 +671,12 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys, target_cropping=None): + def parse_tfrecord(example, features_types, target_keys): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets - :param target_cropping: Optional. Number of pixels to be removed on each side of the target tensor. """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -702,19 +688,13 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} - if target_cropping: - target_parsed = {key: value[target_cropping:-target_cropping, target_cropping:-target_cropping, :] for key, value in target_parsed.items()} - return input_parsed, target_parsed - - def read(self, batch_size, target_keys, target_cropping=None, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. :param batch_size: Size of tensorflow batch :param target_keys: Keys of the target, e.g. ['s2_out'] - :param target_cropping: Number of pixels to be removed on each side of the target. Must be used with a network - architecture coherent with this, i.e. that has a Cropping2D layer in the end :param n_workers: number of workers, e.g. 4 if using 4 GPUs e.g. 12 if using 3 nodes of 4 GPUs :param drop_remainder: whether the last batch should be dropped in the case it has fewer than @@ -727,7 +707,7 @@ class TFRecords: if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, target_cropping=target_cropping) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) # TODO: to be investigated : # 1/ num_parallel_reads useful ? 
I/O bottleneck of not ? -- GitLab From 5186eb5e0d6a771f437e7666b576fee7769c521c Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 11:38:47 +0200 Subject: [PATCH 11/17] ENH: use default arg as None instead {} --- python/otbtf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 922d13db..ad77b7a2 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -172,7 +172,7 @@ class PatchesImagesReader(PatchesReaderBase): :see PatchesReaderBase """ - def __init__(self, filenames_dict, use_streaming=False, scalar_dict={}): + def __init__(self, filenames_dict, use_streaming=False, scalar_dict=None): """ :param filenames_dict: A dict() structured as follow: {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif], @@ -198,7 +198,7 @@ class PatchesImagesReader(PatchesReaderBase): # Scalar dict (e.g. for metadata) # If the scalars are not numpy.ndarray, convert them self.scalar_dict = {key: [i if isinstance(i, np.ndarray) else np.asarray(i) for i in scalars] - for key, scalars in scalar_dict.items()} + for key, scalars in scalar_dict.items()} if scalar_dict else {} # check number of patches in each sources if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1: -- GitLab From 878dde7dd0ab49a28d18acf9ad5907e1a4e70f70 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Wed, 20 Apr 2022 11:44:40 +0200 Subject: [PATCH 12/17] REFAC: change import order --- python/otbtf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index ad77b7a2..cbb96e55 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -29,11 +29,10 @@ import time import logging from abc import ABC, abstractmethod from functools import partial -from tqdm import tqdm - import numpy as np import tensorflow as tf from osgeo import gdal +from tqdm import tqdm # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -581,9 +580,10 @@ class TFRecords: self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None + @staticmethod def _bytes_feature(self, value): """ - Used to convert a value to a type compatible with tf.train.Example. + Convert a value to a type compatible with tf.train.Example. :param value: value :return a bytes_list from a string / byte. 
""" -- GitLab From d223155ef2b4b616fe9d4eaddc2d1d760f61fdba Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 14:25:40 +0200 Subject: [PATCH 13/17] FIX: list indices must be integers --- python/otbtf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index cbb96e55..83cb4fe5 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -225,7 +225,7 @@ class PatchesImagesReader(PatchesReaderBase): # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds[src_key]], axis=0) for + self.patches_buffer = {src_key: np.concatenate([read_as_np_arr(ds) for ds in src_ds], axis=0) for src_key, src_ds in self.gdal_ds.items()} def _get_ds_and_offset_from_index(self, index): -- GitLab From 1a2a42d0eb5ff61ff7266d702d21ef9578b27d4e Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Thu, 21 Apr 2022 14:27:34 +0200 Subject: [PATCH 14/17] FIX: static function --- python/otbtf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index 83cb4fe5..75d188ae 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -581,7 +581,7 @@ class TFRecords: self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None @staticmethod - def _bytes_feature(self, value): + def _bytes_feature(value): """ Convert a value to a type compatible with tf.train.Example. :param value: value -- GitLab From c3717f3b86991d160af74ee09b7bf55a05371efb Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Wed, 20 Apr 2022 12:08:33 +0200 Subject: [PATCH 15/17] (Cherrypick from 14) generalize cropping target to a preprocessing function --- python/otbtf.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 75d188ae..c180a7c4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -672,11 +672,15 @@ class TFRecords: @staticmethod def parse_tfrecord(example, features_types, target_keys): + def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): """ Parse example object to sample dict. :param example: Example object to parse :param features_types: List of types for each feature :param target_keys: list of keys of the targets + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn """ read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types} example_parsed = tf.io.parse_single_example(example, read_features) @@ -688,9 +692,13 @@ class TFRecords: input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys} target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys} + if preprocessing_fn: + input_parsed, target_parsed = preprocessing_fn(input_parsed, target_parsed, **kwargs) + return input_parsed, target_parsed - def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None): + def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None, + preprocessing_fn=None, **kwargs): """ Read all tfrecord files matching with pattern and convert data to tensorflow dataset. 
:param batch_size: Size of tensorflow batch @@ -702,12 +710,16 @@ class TFRecords: False is advisable when evaluating metrics so that all samples are used :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size elements are shuffled using uniform random. + :param preprocessing_fn: Optional. A preprocessing function that takes input, target as args and returns + a tuple (input_preprocessed, target_preprocessed) + :param kwargs: some keywords arguments for preprocessing_fn """ options = tf.data.Options() if shuffle_buffer_size: options.experimental_deterministic = False # disable order, increase speed options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO # for multiworker - parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys) + parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys, + preprocessing_fn=preprocessing_fn, **kwargs) # TODO: to be investigated : # 1/ num_parallel_reads useful ? I/O bottleneck of not ? -- GitLab From 161e33e8cdc9f7024cb5c081f0d0561fdc300b30 Mon Sep 17 00:00:00 2001 From: Vincent Delbar <vincent.delbar@latelescop.fr> Date: Thu, 21 Apr 2022 15:55:49 +0200 Subject: [PATCH 16/17] FIX: indented bloc --- python/otbtf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/otbtf.py b/python/otbtf.py index c180a7c4..b28a1cc4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -671,7 +671,6 @@ class TFRecords: self.save(output_shapes, self.output_shape_file) @staticmethod - def parse_tfrecord(example, features_types, target_keys): def parse_tfrecord(example, features_types, target_keys, preprocessing_fn=None, **kwargs): """ Parse example object to sample dict. -- GitLab From 929dae88414026ac9dc760cb01d23e7d99a4a7ce Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 21 Apr 2022 16:55:58 +0200 Subject: [PATCH 17/17] ENH: generate samples of same type as initial raster --- python/otbtf.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index b28a1cc4..a1cf9bd4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -58,8 +58,11 @@ def read_as_np_arr(gdal_ds, as_patches=True): False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - buffer = gdal_ds.ReadAsArray() + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} + gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize + buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: @@ -68,7 +71,7 @@ def read_as_np_arr(gdal_ds, as_patches=True): else: n_elems = int(gdal_ds.RasterYSize / size_x) size_y = size_x - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) + return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)) # -------------------------------------------------- Buffer class ------------------------------------------------------ @@ -244,8 +247,11 @@ class PatchesImagesReader(PatchesReaderBase): @staticmethod def _read_extract_as_np_arr(gdal_ds, offset): + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} assert gdal_ds is not None psz = gdal_ds.RasterXSize + gdal_type = 
gdal_ds.GetRasterBand(1).DataType yoff = int(offset * psz) assert yoff + psz <= gdal_ds.RasterYSize buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) @@ -254,7 +260,7 @@ class PatchesImagesReader(PatchesReaderBase): else: # single-band raster buffer = np.expand_dims(buffer, axis=2) - return np.float32(buffer) + return buffer.astype(gdal_to_np_types[gdal_type]) def get_sample(self, index): """ -- GitLab
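
A minimal end-to-end sketch of writing TFRecords with the API introduced by this patch series, assuming `python/otbtf.py` is importable as `otbtf`. The file paths, the source names `xs`/`labels` and the `month` scalar are hypothetical placeholders; note that `scalar_dict` holds one value per patches image, attached to every sample read from that image:

    from otbtf import PatchesImagesReader, Dataset

    # One patches image per acquisition for each source (hypothetical paths)
    filenames_dict = {
        "xs": ["xs_patches_1.tif", "xs_patches_2.tif"],
        "labels": ["labels_patches_1.tif", "labels_patches_2.tif"],
    }
    # One scalar per patches image (e.g. acquisition month), converted internally to numpy
    scalar_dict = {"month": [4, 7]}

    reader = PatchesImagesReader(filenames_dict, use_streaming=False, scalar_dict=scalar_dict)
    dataset = Dataset(patches_reader=reader, max_nb_of_samples=1000)  # optional cap on the number of samples
    dataset.to_tfrecords("tfrecords/train", n_samples_per_shard=100, drop_remainder=True)

Each shard is written as `<i>.records` in the output directory, together with `output_types.json` and `output_shape.json`, which `TFRecords` reloads when pointed at the same directory.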
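
Reading the shards written above for training is sketched below, with a hypothetical `preprocessing_fn` that reproduces, on the user side, the target cropping that earlier commits hard-coded in `parse_tfrecord`; extra keyword arguments passed to `read()` are forwarded to `preprocessing_fn`:

    from otbtf import TFRecords

    def crop_targets(inputs, targets, crop=16):
        # Remove `crop` pixels on each side of every (unbatched) target tensor
        targets = {key: value[crop:-crop, crop:-crop, :] for key, value in targets.items()}
        return inputs, targets

    tfrecords = TFRecords("tfrecords/train")  # loads output_types.json / output_shape.json
    train_ds = tfrecords.read(batch_size=8,
                              target_keys=["labels"],
                              shuffle_buffer_size=1000,
                              drop_remainder=True,
                              preprocessing_fn=crop_targets,
                              crop=16)
    # train_ds yields (inputs_dict, targets_dict) batches, e.g. for keras Model.fit()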