Neural Structured Learning (NSL): A TensorFlow Framework To Train Neural Networks With Structured Signals

Neural Structured Learning (NSL) is a TensorFlow framework for training neural networks with structured signals. NSL can handle structured input in two ways: 

(i) As an explicit graph (for Neural Graph Learning)

(ii) As an implicit graph (for Adversarial Learning)

These techniques affect only the training workflow; the model serving workflow remains unchanged, because the framework implements both techniques as forms of regularization that only add extra loss terms during training.
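
As a rough sketch of what this means in practice (the model, the default adversarial config, and the export path below are all illustrative): an NSL wrapper shares the weights of the base model and only contributes an extra loss term during training, so the plain base model is what gets exported and served, with its usual signature.

import tensorflow as tf
import neural_structured_learning as nsl

# Illustrative base model; the NSL wrapper shares its weights.
base_model = tf.keras.Sequential([
    tf.keras.Input((28, 28), name='feature'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])
adv_model = nsl.keras.AdversarialRegularization(
    base_model, adv_config=nsl.configs.make_adv_reg_config())
# ... compile and train adv_model as usual (see the full example below) ...

# For serving, export only the unwrapped base model; its inputs, outputs,
# and saved signature are the same as for a model trained without NSL.
base_model.save('/tmp/exported_model')  # illustrative export path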

Structured signals are used to represent relations among labeled and unlabeled samples. Leveraging these signals during neural network training lets the model exploit both labeled and unlabeled data, which can improve model accuracy, particularly when the amount of labeled data is relatively small.

NSL generalizes Neural Graph Learning and Adversarial Learning. The NSL framework provides the following easy-to-use APIs and tools to train models with structured signals:

  • Keras APIs: enable training with graphs and adversarial perturbations.
  • TF ops and functions: enable training with structure when using lower-level TensorFlow APIs.
  • Tools: build graphs and construct graph inputs for training (see the sketch below).
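
For instance, the graph-building tools can be invoked directly from Python. A minimal sketch, assuming precomputed sample embeddings stored as TFRecords of tf.train.Example protos with an 'id' and an 'embedding' feature (the file paths, similarity threshold, and max_nbrs value are illustrative):

import neural_structured_learning as nsl

# Build a similarity graph by comparing sample embeddings; an edge is
# added between pairs whose similarity exceeds the threshold.
nsl.tools.build_graph(['/tmp/embeddings.tfr'],
                      '/tmp/graph.tsv',
                      similarity_threshold=0.8)

# Augment the labeled training data with neighbor features from the
# graph, producing the input expected by graph regularization.
nsl.tools.pack_nbrs('/tmp/train_data.tfr',
                    '',  # optional unlabeled examples (none here)
                    '/tmp/graph.tsv',
                    '/tmp/nsl_train_data.tfr',
                    add_undirected_edges=True,
                    max_nbrs=3)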

The high-level workflow for building a graph-regularized model using NSL comprises the following steps:

  1. Building a graph.
  2. Augmenting the training data using the graph (and the input example features).
  3. Applying graph regularization to a given model using the augmented training data (sketched below).
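
A minimal sketch of step 3, assuming the neighbor-augmented records produced by the tools above have been parsed into a tf.data.Dataset of feature dictionaries (the feature name, model architecture, and hyperparameters are illustrative):

import tensorflow as tf
import neural_structured_learning as nsl

# Illustrative base model over a 256-dimensional 'words' feature.
base_model = tf.keras.Sequential([
    tf.keras.Input(shape=(256,), name='words'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax')
])

# Use up to 3 neighbors per sample and weight the neighbor loss term by 0.1.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=3,
                                                     multiplier=0.1)
graph_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

graph_model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
# train_dataset is assumed to yield batches that include the neighbor
# features (prefixed with 'NL_nbr_') added by the augmentation step above.
# graph_model.fit(train_dataset, epochs=5)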

Although the steps above do not immediately map onto existing TFX pipeline components, TFX supports custom components, which allow users to implement this custom processing within their TFX pipelines.

Adversarial learning is the other aspect of Neural Structured Learning. Instead of using explicit neighbors from a graph for regularization, adversarial learning creates implicit neighbors dynamically and adversarially to confuse the model, and regularizing against these adversarial neighbors effectively improves the model's robustness. Adversarial learning with NSL is also easy to integrate into a TFX pipeline: no custom components are required, and the only change is updating the Trainer component to invoke the adversarial regularization wrapper API in NSL.

Models trained with samples generated by adding adversarial perturbations are more robust against malicious attacks designed to mislead a model's classification. The example below, taken from the NSL documentation, wraps a simple Keras MNIST classifier with adversarial regularization:

import tensorflow as tf
import neural_structured_learning as nsl

# Prepare data.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Create a base model -- sequential, functional, or subclass.
model = tf.keras.Sequential([
    tf.keras.Input((28, 28), name='feature'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# Wrap the model with adversarial regularization.
adv_config = nsl.configs.make_adv_reg_config(multiplier=0.2, adv_step_size=0.05)
adv_model = nsl.keras.AdversarialRegularization(model, adv_config=adv_config)

# Compile, train, and evaluate.
adv_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
adv_model.fit({'feature': x_train, 'label': y_train}, batch_size=32, epochs=5)
adv_model.evaluate({'feature': x_test, 'label': y_test})

# Source of the above code: https://www.tensorflow.org/neural_structured_learning

Source: https://blog.tensorflow.org/2020/10/neural-structured-learning-in-tfx.html?m=1

Resource: https://www.tensorflow.org/neural_structured_learning

GitHub: https://github.com/tensorflow/neural-structured-learning
