Model Wrappers

The TruLens library supports models implemented in several popular Python neural network frameworks: Keras (with a TensorFlow or Theano backend), TensorFlow, and PyTorch. These frameworks differ in how they implement operations such as gradient computation. To present a unified model API, we define framework-specific ModelWrapper classes that provide the same functionality regardless of the underlying framework. To compute attributions for a model, first wrap it with the trulens.nn.models.get_model_wrapper function, which returns an appropriate ModelWrapper instance.
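To make the idea of a unified wrapper API concrete, here is a minimal sketch of the pattern using only the standard library. It is not the TruLens implementation: `ModelWrapperSketch`, `AnalyticWrapper`, `NumericWrapper`, and `input_sensitivity` are hypothetical names invented for illustration, and the two "frameworks" are toy stand-ins that compute a gradient in different ways.

```python
from abc import ABC, abstractmethod
from typing import Callable


class ModelWrapperSketch(ABC):
    """Hypothetical unified interface (not the real ModelWrapper class)."""

    @abstractmethod
    def gradient(self, x: float) -> float:
        """Return the derivative of the model output with respect to x."""


class AnalyticWrapper(ModelWrapperSketch):
    """Toy 'framework A': the model ships with its own derivative."""

    def __init__(self, df: Callable[[float], float]) -> None:
        self._df = df

    def gradient(self, x: float) -> float:
        return self._df(x)


class NumericWrapper(ModelWrapperSketch):
    """Toy 'framework B': gradients via central finite differences."""

    def __init__(self, f: Callable[[float], float], eps: float = 1e-6) -> None:
        self._f = f
        self._eps = eps

    def gradient(self, x: float) -> float:
        return (self._f(x + self._eps) - self._f(x - self._eps)) / (2 * self._eps)


def input_sensitivity(wrapper: ModelWrapperSketch, x: float) -> float:
    """Downstream code depends only on the shared interface."""
    return abs(wrapper.gradient(x))


# Both wrappers describe f(x) = x**2 but compute its gradient differently.
wrappers = [AnalyticWrapper(lambda x: 2 * x), NumericWrapper(lambda x: x ** 2)]
sensitivities = [input_sensitivity(w, 3.0) for w in wrappers]
```

The caller, `input_sensitivity`, never needs to know which toy framework produced the gradient, which is the property the wrapper layer is meant to provide.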

Some parameters apply only to specific frameworks; the parameter descriptions note which.

get_model_wrapper(model, *, logit_layer=None, replace_softmax=False, softmax_layer=-1, custom_objects=None, device=None, input_tensors=None, output_tensors=None, internal_tensor_dict=None, default_feed_dict=None, session=None, backend=None, force_eval=True, **kwargs)

Returns a ModelWrapper implementation that exposes the components needed for computing attributions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | ModelLike | The model to wrap. If using the TensorFlow 1 backend, this is expected to be a graph object. | required |
| logit_layer | | Supported for Keras and PyTorch models. Specifies the name or index of the layer that produces the logit predictions. | None |
| replace_softmax | bool | Supported for Keras models only. If true, the activation function in the softmax layer (specified by softmax_layer) will be changed to a 'linear' activation. | False |
| softmax_layer | | Supported for Keras models only. Specifies the layer that performs the softmax. This layer should have an activation attribute. Only used when replace_softmax is true. | -1 |
| custom_objects | | Optional, for use with Keras models only. A dictionary of custom objects used by the Keras model. | None |
| device | str | Optional, for use with PyTorch models only. A string specifying the device to run the model on. | None |
| input_tensors | | Required for use with TensorFlow 1 graph models only. A list of tensors representing the input to the model graph. | None |
| output_tensors | | Required for use with TensorFlow 1 graph models only. A list of tensors representing the output of the model graph. | None |
| internal_tensor_dict | | Optional, for use with TensorFlow 1 graph models only. A dictionary mapping user-selected layer names to the internal tensors in the model graph that the user would like to expose. This is provided to give more human-readable names to the layers if desired. Internal tensors can also be accessed via the name given to them by TensorFlow. | None |
| default_feed_dict | | Optional, for use with TensorFlow 1 graph models only. A dictionary of default values to give to tensors in the model graph. | None |
| session | | Optional, for use with TensorFlow 1 graph models only. A tf.Session object to run the model graph in. If None, a new temporary session will be generated every time the model is run. | None |
| backend | | Optional, for forcing a specific backend. String values recognized are pytorch, tensorflow, keras, or tf.keras. | None |
| force_eval | | Optional, for use with PyTorch models only. If True, forces a model.eval() call; if False, the model's current state is retained. | True |
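Because several parameters only matter for one framework, it can help to check a set of keyword arguments before wrapping a model. The sketch below encodes the framework notes from the parameter descriptions above as plain data. It is an illustration, not part of TruLens: `ACCEPTED`, `REQUIRED`, `check_kwargs`, and the framework keys (`"keras"`, `"pytorch"`, `"tensorflow1"`) are hypothetical names chosen for this example.

```python
from typing import Any, Dict, List, Tuple

# Framework-specific parameters, transcribed from the parameter descriptions.
ACCEPTED = {
    "keras": {"logit_layer", "replace_softmax", "softmax_layer", "custom_objects"},
    "pytorch": {"logit_layer", "device", "force_eval"},
    "tensorflow1": {
        "input_tensors", "output_tensors", "internal_tensor_dict",
        "default_feed_dict", "session",
    },
}

# Parameters the descriptions mark as required for a given framework.
REQUIRED = {"tensorflow1": ("input_tensors", "output_tensors")}


def check_kwargs(
    framework: str, kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Split kwargs into (relevant, ignored names, missing required names)."""
    accepted = ACCEPTED[framework]
    relevant = {k: v for k, v in kwargs.items() if k in accepted}
    ignored = sorted(set(kwargs) - accepted)
    missing = sorted(
        k for k in REQUIRED.get(framework, ()) if kwargs.get(k) is None
    )
    return relevant, ignored, missing


# A PyTorch call that also passes a Keras-only option.
relevant, ignored, missing = check_kwargs(
    "pytorch", {"logit_layer": -1, "device": "cpu", "replace_softmax": True}
)

# A TensorFlow 1 call that forgets output_tensors.
_, _, tf1_missing = check_kwargs("tensorflow1", {"input_tensors": [object()]})
```

Here `ignored` lists options that the descriptions do not associate with the chosen framework, and `tf1_missing` flags a required TensorFlow 1 argument that was not supplied.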
Source code in trulens_explain/trulens/nn/models/__init__.py
def get_model_wrapper(
    model: ModelLike,
    *,
    logit_layer=None,
    replace_softmax: bool = False,
    softmax_layer=-1,
    custom_objects=None,
    device: str = None,
    input_tensors=None,
    output_tensors=None,
    internal_tensor_dict=None,
    default_feed_dict=None,
    session=None,
    backend=None,
    force_eval=True,
    **kwargs
):
    """
    Returns a ModelWrapper implementation that exposes the components needed for computing attributions.

    Parameters:
        model:
            The model to wrap. If using the TensorFlow 1 backend, this is 
            expected to be a graph object.

        logit_layer:
            _Supported for Keras and Pytorch models._ 
            Specifies the name or index of the layer that produces the
            logit predictions. 

        replace_softmax:
            _Supported for Keras models only._ If true, the activation
            function in the softmax layer (specified by `softmax_layer`) 
            will be changed to a `'linear'` activation. 

        softmax_layer:
            _Supported for Keras models only._ Specifies the layer that
            performs the softmax. This layer should have an `activation`
            attribute. Only used when `replace_softmax` is true.

        custom_objects:
            _Optional, for use with Keras models only._ A dictionary of
            custom objects used by the Keras model.

        device:
            _Optional, for use with Pytorch models only._ A string
            specifying the device to run the model on.

        input_tensors:
            _Required for use with TensorFlow 1 graph models only._ A list
            of tensors representing the input to the model graph.

        output_tensors:
            _Required for use with TensorFlow 1 graph models only._ A list
            of tensors representing the output to the model graph.

        internal_tensor_dict:
            _Optional, for use with TensorFlow 1 graph models only._ A
            dictionary mapping user-selected layer names to the internal
            tensors in the model graph that the user would like to expose.
            This is provided to give more human-readable names to the layers
            if desired. Internal tensors can also be accessed via the name
            given to them by tensorflow.

        default_feed_dict:
            _Optional, for use with TensorFlow 1 graph models only._ A
            dictionary of default values to give to tensors in the model
            graph.

        session:
            _Optional, for use with TensorFlow 1 graph models only._ A 
            `tf.Session` object to run the model graph in. If `None`, a new
            temporary session will be generated every time the model is run.

        backend:
            _Optional, for forcing a specific backend._ String values recognized
            are pytorch, tensorflow, keras, or tf.keras.

        force_eval:
            _Optional, for use with PyTorch models only._ If `True`, forces a
            `model.eval()` call. If `False`, the model's current state is
            retained.

    Returns: ModelWrapper
    """

    if 'input_shape' in kwargs:
        tru_logger.deprecate(
            f"get_model_wrapper: input_shape parameter is no longer used and will be removed in the future"
        )
        del kwargs['input_shape']
    if 'input_dtype' in kwargs:
        tru_logger.deprecate(
            f"get_model_wrapper: input_dtype parameter is no longer used and will be removed in the future"
        )
        del kwargs['input_dtype']

    # get existing backend
    B = get_backend(suppress_warnings=True)

    if backend is None:
        backend = discern_backend(model)
        tru_logger.info(
            "Detected {} backend for {}.".format(
                backend.name.lower(), type(model)
            )
        )
    else:
        backend = Backend.from_name(backend)
    if B is None or (backend is not Backend.UNKNOWN and B.backend != backend):
        tru_logger.info(
            "Changing backend from {} to {}.".format(
                None if B is None else B.backend, backend
            )
        )
        os.environ['TRULENS_BACKEND'] = backend.name.lower()
        B = get_backend()
    else:
        tru_logger.info("Using backend {}.".format(B.backend))
    tru_logger.info(
        "If this seems incorrect, you can force the correct backend by passing the `backend` parameter directly into your get_model_wrapper call."
    )
    if B.backend.is_keras_derivative():
        from trulens.nn.models.keras import KerasModelWrapper
        return KerasModelWrapper(
            model,
            logit_layer=logit_layer,
            replace_softmax=replace_softmax,
            softmax_layer=softmax_layer,
            custom_objects=custom_objects
        )

    elif B.backend == Backend.PYTORCH:
        from trulens.nn.models.pytorch import PytorchModelWrapper
        return PytorchModelWrapper(
            model,
            logit_layer=logit_layer,
            device=device,
            force_eval=force_eval
        )
    elif B.backend == Backend.TENSORFLOW:
        import tensorflow as tf
        if tf.__version__.startswith('2'):
            from trulens.nn.models.tensorflow_v2 import Tensorflow2ModelWrapper
            return Tensorflow2ModelWrapper(
                model,
                logit_layer=logit_layer,
                replace_softmax=replace_softmax,
                softmax_layer=softmax_layer,
                custom_objects=custom_objects
            )
        else:
            from trulens.nn.models.tensorflow_v1 import TensorflowModelWrapper
            if input_tensors is None:
                tru_logger.error(
                    'tensorflow1 model must pass parameter: input_tensors'
                )
            if output_tensors is None:
                tru_logger.error(
                    'tensorflow1 model must pass parameter: output_tensors'
                )
            return TensorflowModelWrapper(
                model,
                input_tensors=input_tensors,
                output_tensors=output_tensors,
                internal_tensor_dict=internal_tensor_dict,
                session=session
            )
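
The backend-switching condition in the source above can be hard to read inline. The following sketch isolates it as a small pure function so each case can be checked by hand. It is a simplified illustration under stated assumptions: the `Backend` enum here is a stand-in defined for this example (not imported from TruLens), and `resolve_backend` is a hypothetical helper that only mirrors the `if B is None or (...)` test; it does not set `TRULENS_BACKEND` or load anything.

```python
from enum import Enum
from typing import Optional


class Backend(Enum):
    """Simplified stand-in for the library's backend enum (assumption)."""
    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    KERAS = "keras"
    TF_KERAS = "tf.keras"
    UNKNOWN = "unknown"


def resolve_backend(active: Optional[Backend], requested: Backend) -> Backend:
    """Return the backend that would be in effect after the check.

    active:    the backend already loaded, or None if none is loaded.
    requested: the backend detected from the model or passed explicitly.
    """
    # Switch when nothing is loaded yet, or when a known backend was
    # requested that differs from the one currently loaded.
    if active is None or (
        requested is not Backend.UNKNOWN and active != requested
    ):
        return requested
    # Otherwise keep whatever is already loaded.
    return active


first_use = resolve_backend(None, Backend.PYTORCH)
switch = resolve_backend(Backend.KERAS, Backend.PYTORCH)
undetected = resolve_backend(Backend.KERAS, Backend.UNKNOWN)
unchanged = resolve_backend(Backend.KERAS, Backend.KERAS)
```

Read this way, a model whose backend cannot be detected leaves an already loaded backend in place, which is the case where passing the `backend` parameter explicitly is most useful.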