Skip to content

Add helpers for common preprocessing functions

For example, one-hot encoding of the label patches:

# Cast the label patches to int32, drop their trailing size-1 axis
# (tf.squeeze with axis=-1 requires that axis to have size 1), then expand
# each integer label into a one-hot vector of length nb_cls.
tf.one_hot(
    indices=tf.squeeze(tf.cast(sample["labels_patches"], dtype=tf.int32), axis=-1),
    depth=nb_cls,
)
class DilatedMask(keras.layers.Layer):
    """
    Compute a binary mask of the no-data values of the input, dilated by a square kernel.

    The output is 1 where the input equals `nodata_value` or lies within
    `dilatation_radius` pixels (square neighbourhood) of such a value, and 0 elsewhere.
    """
    def __init__(self, nodata_value, dilatation_radius, name=None, **kwargs):
        """
        :param nodata_value: input value considered as no-data
        :param dilatation_radius: radius (in pixels) of the square dilatation kernel, whose
                                  side is 1 + 2 * dilatation_radius
        :param name: layer name
        :param kwargs: other keyword arguments forwarded to keras.layers.Layer (e.g. the
                       "trainable"/"dtype" keys passed back by from_config())
        """
        # Initialize the base Layer before setting any attribute, as Keras expects
        super().__init__(name=name, **kwargs)
        self.nodata_value = nodata_value
        self.dilatation_radius = dilatation_radius

    def call(self, inp):
        """
        :param inp: input tensor in NHWC layout. The dilatation kernel has a depth of 1 and
                    tf.nn.dilation2d requires the kernel depth to match the input depth,
                    so `inp` must have a single channel.
        :return: uint8 tensor with the same shape as `inp`, 1 on dilated no-data areas, 0 elsewhere
        """
        # Compute a binary mask from the input
        nodata_mask = tf.cast(tf.math.equal(inp, self.nodata_value), tf.uint8)

        dilatation_size = 1 + 2 * self.dilatation_radius
        # With an all-zero structuring element, dilation2d reduces to a max over the window,
        # i.e. a binary dilatation of the mask, cf https://stackoverflow.com/q/54686895/13711499
        kernel = tf.zeros((dilatation_size, dilatation_size, 1), dtype=tf.uint8)
        # dilation2d keeps the input dtype, so the result is already uint8
        return tf.nn.dilation2d(input=nodata_mask, filters=kernel, strides=[1, 1, 1, 1], padding="SAME",
                                data_format="NHWC", dilations=[1, 1, 1, 1], name="dilatation_tf")

    def get_config(self):
        """
        Return the layer configuration, so that the layer can be re-created with from_config().
        """
        config = super().get_config()
        config.update({
            "nodata_value": self.nodata_value,
            "dilatation_radius": self.dilatation_radius,
        })
        return config


class ApplyMask(keras.layers.Layer):
    """
    Replace the masked values of a tensor with a constant no-data value.
    """
    def __init__(self, out_nodata, name=None, **kwargs):
        """
        :param out_nodata: value written at the masked positions
        :param name: layer name
        :param kwargs: other keyword arguments forwarded to keras.layers.Layer (e.g. the
                       "trainable"/"dtype" keys passed back by from_config())
        """
        super().__init__(name=name, **kwargs)
        self.out_nodata = out_nodata

    def call(self, inputs):
        """
        :param inputs: [mask, input]. Mask is a binary mask, where 1 indicates the values to be masked on the input.
        :return: tensor with the dtype of `input`, equal to `out_nodata` where mask == 1 and to `input` elsewhere
        """
        mask, inp = inputs
        # Cast the fill value to the input dtype, so that both branches of tf.where share
        # the same dtype, whatever the input type is (a bare Python float does not match
        # integer inputs). Note: for integer inputs, a fractional out_nodata is truncated.
        nodata = tf.cast(self.out_nodata, inp.dtype)
        return tf.where(tf.math.equal(mask, 1), nodata, inp)

    def get_config(self):
        """
        Return the layer configuration, so that the layer can be re-created with from_config().
        """
        config = super().get_config()
        config.update({"out_nodata": self.out_nodata})
        return config


class ScalarsTile(keras.layers.Layer):
    """
    Duplicate some scalars over a whole array.

    Each scalar input becomes one channel of the output, whose height and width match
    those of a reference tensor.
    Simple example with only one scalar = 0.152: output [[0.152, 0.152, 0.152],
                                                         [0.152, 0.152, 0.152],
                                                         [0.152, 0.152, 0.152]]
    """
    def __init__(self, name=None, **kwargs):
        """
        :param name: layer name
        :param kwargs: other keyword arguments forwarded to keras.layers.Layer (e.g. the
                       "trainable"/"dtype" keys passed back by from_config())
        """
        super().__init__(name=name, **kwargs)

    def call(self, inputs):
        """
        :param inputs: [reference, scalar inputs]. Reference is the tensor whose shape has to be
                       matched along axes 1 and 2. Since the tile multiples below have 4 entries,
                       each scalar input must be a 1-D tensor (one value per batch element).
        :return: tensor of shape (batch, ref height, ref width, number of scalar inputs)
        """
        ref, scalar_inputs = inputs
        # (batch,) tensors -> (batch, n_scalars)
        stacked = tf.stack(scalar_inputs, axis=-1)
        # -> (batch, 1, 1, n_scalars)
        stacked = tf.expand_dims(tf.expand_dims(stacked, axis=1), axis=1)
        ref_shape = tf.shape(ref)
        return tf.tile(stacked, [1, ref_shape[1], ref_shape[2], 1])

class Argmax(keras.layers.Layer):
    """
    Compute the argmax of a tensor along its last axis. For example, for a vector A=[0.1, 0.3, 0.6],
    the output is 2 (A[2] is the max).
    Useful to transform a multi-band probability map into a categorical map.
    """
    def __init__(self, name=None, **kwargs):
        """
        :param name: layer name
        :param kwargs: other keyword arguments forwarded to keras.layers.Layer (e.g. the
                       "trainable"/"dtype" keys passed back by from_config())
        """
        super().__init__(name=name, **kwargs)

    def call(self, inputs):
        """
        :param inputs: input tensor
        :return: index of the max along the last axis (tf.math.argmax default dtype, int64),
                 with that axis kept as a size-1 axis
        """
        return tf.expand_dims(tf.math.argmax(inputs, axis=-1), axis=-1)


class Max(keras.layers.Layer):
    """
    Compute the max of a tensor along its last axis. For example, for a vector [0.1, 0.3, 0.6],
    the output is 0.6.
    Useful to transform a multi-band probability map into a "confidence" map.
    """
    def __init__(self, name=None, **kwargs):
        """
        :param name: layer name
        :param kwargs: other keyword arguments forwarded to keras.layers.Layer (e.g. the
                       "trainable"/"dtype" keys passed back by from_config())
        """
        super().__init__(name=name, **kwargs)

    def call(self, inputs):
        """
        :param inputs: input tensor
        :return: max along the last axis, with that axis kept as a size-1 axis
        """
        return tf.expand_dims(tf.math.reduce_max(inputs, axis=-1), axis=-1)
Edited by Cresson Remi