model.image_classification package
Submodules
model.image_classification.image_classifier module
- class model.image_classification.image_classifier.ImageClassifier(name: str, model, inputs_spec: ImageSpec, preadapter: Callable[[...], Tuple[Tuple[Any, ...], Dict[str, Any]]] | None = None, postadapter: Callable[[Any], Any] | None = None)
Bases: ArmoryModel, ModelProtocol
Wrapper around a model that produces image classification predictions.
This wrapper automatically applies a postadapter to the model's outputs to extract the logits, probs, or scores attribute from the returned object, if any such attribute is found. Otherwise, the unmodified model output is returned.
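The attribute-extraction behavior described above can be pictured with a small sketch. This is not the library's implementation: the function name `extract_predictions`, the constant `_CANDIDATE_ATTRIBUTES`, and the order in which the attributes are checked are all assumptions made for illustration.

```python
from types import SimpleNamespace
from typing import Any

# Attribute names taken from the description above. The lookup order is an
# assumption of this sketch, not something the documentation specifies.
_CANDIDATE_ATTRIBUTES = ("logits", "probs", "scores")


def extract_predictions(output: Any) -> Any:
    """Hypothetical postadapter: return the first matching attribute,
    or the output unchanged when none of the attributes exist."""
    for name in _CANDIDATE_ATTRIBUTES:
        if hasattr(output, name):
            return getattr(output, name)
    return output


# An output object that exposes a `logits` attribute is unwrapped.
structured_output = SimpleNamespace(logits=[0.1, 2.3, -1.0])
print(extract_predictions(structured_output))

# A plain list has none of the attributes, so it is returned as-is.
raw_output = [0.2, 0.8]
print(extract_predictions(raw_output))
```

A model that already returns a bare tensor or array therefore needs no custom postadapter under this scheme, while one that returns a structured output object has its prediction field pulled out automatically.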
Example:
    import armory.data
    from armory.model.image_classification import ImageClassifier

    # assuming `model` has been defined elsewhere
    classifier = ImageClassifier(
        name="My model",
        model=model,
        inputs_spec=armory.data.TorchImageSpec(
            dim=armory.data.ImageDimensions.CHW,
            scale=armory.data.Scale(
                dtype=armory.data.DataType.FLOAT,
                max=1.0,
            ),
        ),
    )
- loss(batch: ImageClassificationBatch)
Calculates the loss for the given batch.
- predict(batch: ImageClassificationBatch)
Invokes the wrapped model using the image inputs in the given batch and updates the image classification predictions in the batch.
- Parameters:
batch (ImageClassificationBatch) – Image classification batch
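As a rough picture of how a call might pass through the optional preadapter and postadapter from the constructor signature, the sketch below uses a stand-in function, `call_with_adapters`, and a toy model. Both are hypothetical, and the assumption that the preadapter receives the model inputs and returns the `(args, kwargs)` pair to forward is inferred from the type annotation only. The sketch does not show how the batch's predictions are updated.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Type aliases mirroring the annotations in the constructor signature.
Preadapter = Callable[..., Tuple[Tuple[Any, ...], Dict[str, Any]]]
Postadapter = Callable[[Any], Any]


def call_with_adapters(
    model: Callable[..., Any],
    inputs: Any,
    preadapter: Optional[Preadapter] = None,
    postadapter: Optional[Postadapter] = None,
) -> Any:
    """Hypothetical helper: adapt the inputs, call the model, adapt the output."""
    # Assumption: without a preadapter, the inputs are passed as the single
    # positional argument and no keyword arguments are added.
    args, kwargs = preadapter(inputs) if preadapter else ((inputs,), {})
    output = model(*args, **kwargs)
    return postadapter(output) if postadapter else output


def toy_model(values, scale=1.0):
    # Toy stand-in for a real classifier: scales each input value.
    return [v * scale for v in values]


def add_scale(values):
    # Preadapter: forward the inputs plus an extra keyword argument.
    return (values,), {"scale": 2.0}


result = call_with_adapters(toy_model, [1.0, 2.0], preadapter=add_scale, postadapter=max)
print(result)
```

Keeping both adapters optional means a model whose call signature and output already match what the wrapper expects can be wrapped without any extra glue code.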
- training: bool
Module contents
This package contains model wrappers for image classification.