Class BaseModel

Class Documentation

class BaseModel

Owning handle around a TorchScript module with move-only semantics.

BaseModel owns a torch::jit::Module together with its device and dtype. The underlying Torch module has reference-like semantics: copying a torch::jit::Module does not clone the model, but creates another handle to the same internal object.

Because BaseModel provides mutating operations (for example convertTo()), making it copyable would introduce surprising aliasing:

BaseModel A = ...;
BaseModel B = A;          // hypothetical copy: B aliases the same JIT module as A
B.convertTo<float>();     // mutates the module shared by A and B

To avoid this class of bugs, BaseModel is intentionally non-copyable and move-only. Ownership can be transferred (returned from factories, stored in containers), but accidental copies are disallowed, similar to std::unique_ptr.
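The ownership rules above can be illustrated with a small, self-contained sketch that does not depend on LibTorch. JitHandle is a hypothetical stand-in for torch::jit::Module that shares its state through a std::shared_ptr, and OwningModel is a simplified model of BaseModel, not the real class. The static_asserts check at compile time that copies are rejected while moves are allowed.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for torch::jit::Module: copying a JitHandle copies a
// shared_ptr, so both copies refer to the same underlying state, mirroring
// the reference-like semantics of the real Torch module.
struct JitHandle {
  struct State {
    std::string DType = "float64";
  };
  std::shared_ptr<State> Impl = std::make_shared<State>();
};

// Simplified model of the BaseModel ownership rules (not the real class).
class OwningModel {
public:
  OwningModel() = default;

  // Copies are disabled: a copy would alias the same JitHandle state.
  OwningModel(const OwningModel &) = delete;
  OwningModel &operator=(const OwningModel &) = delete;

  // Moves transfer ownership of the handle to the new object.
  OwningModel(OwningModel &&) = default;
  OwningModel &operator=(OwningModel &&) = default;

  const std::string &dtype() const { return Handle.Impl->DType; }
  void convertToFloat32() { Handle.Impl->DType = "float32"; }

private:
  JitHandle Handle;
};

static_assert(!std::is_copy_constructible_v<OwningModel>);
static_assert(!std::is_copy_assignable_v<OwningModel>);
static_assert(std::is_nothrow_move_constructible_v<OwningModel>);

// Factory returning by value: the result is elided or moved, never copied.
inline OwningModel makeModel() { return OwningModel{}; }
```

Containers such as std::vector accept move-only element types, so models can still be returned from factories and stored in containers; only accidental copies are ruled out.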

Subclassed by ams::ml::InferenceModel

Public Types

using DType = torch::ScalarType

Scalar type used by the model weights.

using DeviceType = torch::Device

Device type on which the model resides.

using HashT = uint64_t

Integer hash type used to identify a model.

Public Functions

inline bool isDevice() const

Return true if the model resides on an accelerator device (e.g. a GPU) rather than on the CPU.

This is typically equivalent to checking whether the underlying device is a CUDA or HIP device, as opposed to a CPU device.

inline torch::Device getDevice() const

Return the Torch device on which the model currently lives.

inline DType getDType() const

Return the scalar dtype of the model weights.

template<typename ScalarT>
inline bool isType() const

Return true if the model dtype matches the given scalar type.

Example:

if (Model.isType<double>()) { ... }
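One way isType<ScalarT>() could be realized is a compile-time trait that maps a C++ scalar type to a dtype tag, followed by a comparison with the stored dtype. In this sketch ScalarKind, ScalarKindOf and TypedModel are hypothetical stand-ins for torch::ScalarType, whatever mapping the real implementation uses, and BaseModel respectively.

```cpp
#include <cassert>

// Hypothetical stand-in for torch::ScalarType, reduced to two members.
enum class ScalarKind { Float, Double };

// Hypothetical trait mapping a C++ scalar type to its ScalarKind. Only the
// specializations below are defined, so an unsupported ScalarT is a
// compile-time error rather than a silent mismatch.
template <typename T> struct ScalarKindOf;
template <> struct ScalarKindOf<float> {
  static constexpr ScalarKind value = ScalarKind::Float;
};
template <> struct ScalarKindOf<double> {
  static constexpr ScalarKind value = ScalarKind::Double;
};

// Simplified model that only records its dtype.
struct TypedModel {
  ScalarKind DType;

  // Compare the stored dtype with the dtype corresponding to ScalarT.
  template <typename ScalarT> bool isType() const {
    return DType == ScalarKindOf<ScalarT>::value;
  }
};
```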

template<typename ScalarT>
inline AMSStatus convertTo(std::optional<torch::Device> TargetDevice = std::nullopt)

Convert the model to the requested device and dtype.

This is a combined convenience operation that migrates the model to a target device and scalar type (ScalarT) in one step. The function modifies the current object in place and reports the outcome through the returned AMSStatus.
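The following sketch models the documented behavior of convertTo() without LibTorch: a single call updates both the device and the scalar type of the object it is invoked on. Device, ScalarKind and Status are hypothetical stand-ins for torch::Device, torch::ScalarType and AMSStatus, and treating an empty TargetDevice as "keep the current device" is an assumption that this page does not confirm. With LibTorch, the actual weight migration would presumably go through torch::jit::Module::to(device, dtype).

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for torch::Device, torch::ScalarType and AMSStatus.
enum class Device { CPU, CUDA };
enum class ScalarKind { Float, Double };
enum class Status { Ok, Error };

// Hypothetical mapping from a C++ scalar type to its ScalarKind.
template <typename T> struct ScalarKindOf;
template <> struct ScalarKindOf<float> {
  static constexpr ScalarKind value = ScalarKind::Float;
};
template <> struct ScalarKindOf<double> {
  static constexpr ScalarKind value = ScalarKind::Double;
};

class ConvertibleModel {
public:
  ConvertibleModel(Device D, ScalarKind K) : Dev(D), DType(K) {}

  // Sketch of convertTo(): change device and dtype in one step, in place.
  // Assumption: an empty TargetDevice keeps the current device.
  template <typename ScalarT>
  Status convertTo(std::optional<Device> TargetDevice = std::nullopt) {
    const Device NewDev = TargetDevice.value_or(Dev);
    const ScalarKind NewType = ScalarKindOf<ScalarT>::value;
    // A real implementation would move and cast the weights here before
    // updating the recorded metadata.
    Dev = NewDev;
    DType = NewType;
    return Status::Ok;
  }

  Device device() const { return Dev; }
  ScalarKind dtype() const { return DType; }

private:
  Device Dev;
  ScalarKind DType;
};
```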

BaseModel(const BaseModel&) = delete
BaseModel &operator=(const BaseModel&) = delete
BaseModel(BaseModel&&) = default
BaseModel &operator=(BaseModel&&) = default

Public Static Functions

static AMSExpected<std::unique_ptr<BaseModel>> load(const AbstractModel &Descriptor)

Load a model from an AbstractModel descriptor using the on-disk dtype and device encoded in the model file.

On success, returns a std::unique_ptr owning a fully constructed BaseModel instance. On error, returns an AMSError describing the failure.
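Callers are expected to check the result before using the model. The sketch below mimics that contract with standard types only: Descriptor, Error and Expected<T> (built on std::variant) are hypothetical stand-ins for AbstractModel, AMSError and AMSExpected<T>, whose real interfaces may differ, and the empty-path check is an invented failure case used purely for illustration.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <variant>

// Hypothetical stand-ins for AbstractModel, AMSError and AMSExpected<T>.
struct Descriptor {
  std::string Path;
};
struct Error {
  std::string Message;
};
template <typename T> using Expected = std::variant<T, Error>;

// Hypothetical simplified model type.
struct Model {
  explicit Model(std::string P) : Path(std::move(P)) {}
  std::string Path;
};

// Factory in the style of load(): it yields either an owning pointer to a
// fully constructed model or an error, never a half-initialized object.
inline Expected<std::unique_ptr<Model>> load(const Descriptor &D) {
  if (D.Path.empty())
    return Error{"empty model path"}; // invented failure case
  return std::make_unique<Model>(D.Path);
}
```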

Protected Functions

BaseModel(const AbstractModel &AModel)

Construct a BaseModel from an AbstractModel descriptor.

This is primarily intended for use by factory functions (e.g. load()) and subclasses such as InferenceModel.

inline torch::jit::Module &getJITModel()

Mutable access to the underlying Torch module.

Intended for subclasses that need to configure the module (e.g. set eval mode or attach buffers).
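A subclass can reach the module through the protected accessors, for example to switch it to evaluation mode once at construction. In this sketch FakeModule, Base and Inference are hypothetical simplifications of torch::jit::Module, BaseModel and InferenceModel; with LibTorch the corresponding call would be torch::jit::Module::eval().

```cpp
#include <cassert>

// Hypothetical stand-in for torch::jit::Module with a training flag.
struct FakeModule {
  bool Training = true;
  void eval() { Training = false; }
};

// Simplified base class exposing the module only to subclasses.
class Base {
protected:
  FakeModule &getJITModel() { return Module; }
  const FakeModule &getJITModel() const { return Module; }

private:
  FakeModule Module;
};

// Hypothetical subclass that configures the module via the mutable accessor
// and inspects it via the const accessor.
class Inference : public Base {
public:
  Inference() { getJITModel().eval(); }
  bool isTraining() const { return getJITModel().Training; }
};
```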

inline const torch::jit::Module &getJITModel() const

Const access to the underlying Torch module.

inline void setDevice(torch::Device Device)

Set the recorded device of the model. This does not move or copy the module to that device.

inline void setDType(DType ModelDType)

Set the recorded scalar dtype of the model. This does not convert the weights to that dtype.