model.image_classification package

Submodules

model.image_classification.image_classifier module

class model.image_classification.image_classifier.ImageClassifier(name: str, model, inputs_spec: ImageSpec, preadapter: Callable[..., Tuple[Tuple[Any, ...], Dict[str, Any]]] | None = None, postadapter: Callable[[Any], Any] | None = None)

Bases: ArmoryModel, ModelProtocol

Wrapper around a model that produces image classification predictions.

This wrapper automatically applies a postadapter to the model's outputs that extracts the `logits`, `probs`, or `scores` attribute from the returned object if any such attribute is found. Otherwise, the unmodified model output is returned.
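The attribute-extraction behavior described above can be sketched in plain Python. This is a minimal illustration, not the library's code: `extract_predictions` is a hypothetical name, and checking the attributes in the order `logits`, `probs`, `scores` is an assumption about precedence when an output carries more than one of them.

```python
from types import SimpleNamespace
from typing import Any

# Attribute names to look for; the precedence order is an assumption.
_PREDICTION_ATTRS = ("logits", "probs", "scores")


def extract_predictions(output: Any) -> Any:
    """Hypothetical stand-in for the default postadapter.

    Returns the first of `logits`, `probs`, or `scores` found on the
    model output; otherwise returns the output unchanged.
    """
    for attr in _PREDICTION_ATTRS:
        if hasattr(output, attr):
            return getattr(output, attr)
    return output


# An output object exposing a `logits` attribute (e.g. Hugging Face style)
hf_style = SimpleNamespace(logits=[[0.1, 2.3, -0.5]])
print(extract_predictions(hf_style))      # [[0.1, 2.3, -0.5]]

# A raw output with none of the attributes passes through untouched
print(extract_predictions([[0.7, 0.3]]))  # [[0.7, 0.3]]
```

This is why models that return rich output objects and models that return bare tensors can both be wrapped without a custom postadapter.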

Example:

import armory.data
from armory.model.image_classification import ImageClassifier

# assuming `model` has been defined elsewhere
classifier = ImageClassifier(
    name="My model",
    model=model,
    inputs_spec=armory.data.TorchImageSpec(
        dim=armory.data.ImageDimensions.CHW,
        scale=armory.data.Scale(
            dtype=armory.data.DataType.FLOAT,
            max=1.0,
        ),
    ),
)

loss(batch: ImageClassificationBatch)

Calculates the loss for the given batch.
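The loss function used here is not specified. A common choice for image classification is mean softmax cross-entropy between the predicted logits and the integer class targets. The sketch below shows that computation under this assumption. `cross_entropy` is a hypothetical helper written with the standard library only, and it is not confirmed to match what `loss` does internally.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean softmax cross-entropy over a batch (hypothetical helper).

    For each sample: log(sum_j exp(z_j)) - z_target.
    """
    total = 0.0
    for row, target in zip(logits, targets):
        # Log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_norm - row[target]
    return total / len(targets)


# Two equal logits give a uniform softmax, so the loss is log(2)
print(cross_entropy([[0.0, 0.0]], [0]))
```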

predict(batch: ImageClassificationBatch)

Invokes the wrapped model using the image inputs in the given batch and updates the image classification predictions in the batch.

Parameters:

batch (ImageClassificationBatch) – Image classification batch
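The constructor signature implies a call flow for `predict`: the optional preadapter turns the inputs into a `(args, kwargs)` pair for the wrapped model, and the optional postadapter reshapes the model's raw output. The sketch below illustrates that flow with hypothetical stand-ins (`ToyBatch`, `ToyImageClassifier`). It assumes the adapters are applied around every model call and omits the conversion of images to the format described by `inputs_spec`, which the real wrapper presumably performs.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

# Adapter types mirroring the ImageClassifier constructor signature
PreAdapter = Callable[..., Tuple[Tuple[Any, ...], Dict[str, Any]]]
PostAdapter = Callable[[Any], Any]


@dataclass
class ToyBatch:
    """Hypothetical stand-in for ImageClassificationBatch."""
    images: Any
    predictions: Optional[Any] = None


class ToyImageClassifier:
    """Hypothetical wrapper illustrating the pre/post-adapter flow."""

    def __init__(self, model: Callable[..., Any],
                 preadapter: Optional[PreAdapter] = None,
                 postadapter: Optional[PostAdapter] = None) -> None:
        self.model = model
        self.preadapter = preadapter
        self.postadapter = postadapter

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # The preadapter rewrites the call into new positional/keyword args
        if self.preadapter is not None:
            args, kwargs = self.preadapter(*args, **kwargs)
        output = self.model(*args, **kwargs)
        # The postadapter reshapes the raw model output
        if self.postadapter is not None:
            output = self.postadapter(output)
        return output

    def predict(self, batch: ToyBatch) -> None:
        # Run the model on the batch images and store predictions in place
        batch.predictions = self(batch.images)


def model(pixel_values):
    # Toy model that expects a keyword argument and returns a dict
    return {"logits": [x * 2 for x in pixel_values]}


clf = ToyImageClassifier(
    model,
    preadapter=lambda images: ((), {"pixel_values": images}),
    postadapter=lambda out: out["logits"],
)
batch = ToyBatch(images=[1, 2, 3])
clf.predict(batch)
print(batch.predictions)  # [2, 4, 6]
```

Keeping the adapters separate from the model means a model with a nonstandard call signature or output type can be wrapped without modifying the model itself.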

training: bool

Module contents

This package contains model wrappers for image classification.