Class BaseModel
Defined in File Model.hpp
Nested Relationships
Nested Types
Inheritance Relationships
Derived Type
public ams::ml::InferenceModel (Class InferenceModel)
Class Documentation
-
class BaseModel
Owning handle around a TorchScript module with move-only semantics.
BaseModel owns a torch::jit::Module together with its device and dtype. The underlying Torch module has reference-like semantics: copying a torch::jit::Module does not clone the model, but creates another handle to the same internal object.
Because BaseModel provides mutating operations (for example convertTo()), allowing it to be copyable would introduce surprising aliasing:
BaseModel A = ...;
BaseModel B = A;        // B aliases the same JIT module as A
B.convertTo<float>();   // mutates both A and B
To avoid this class of bugs, BaseModel is intentionally non-copyable and move-only. Ownership can be transferred (returned from factories, stored in containers), but accidental copies are disallowed, similar to std::unique_ptr.
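The aliasing hazard and the move-only remedy described above can be reproduced without Torch. The following is a minimal sketch, assuming a hypothetical FakeModule whose copies share state through a std::shared_ptr (standing in for the reference-like torch::jit::Module) and a hypothetical MoveOnlyModel wrapper. Neither type is part of AMS or Torch.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a reference-like handle such as
// torch::jit::Module: copying it shares state instead of cloning it.
struct FakeModule {
  std::shared_ptr<int> Bits = std::make_shared<int>(64);
};

// Hypothetical move-only owner, mirroring the policy described for BaseModel.
class MoveOnlyModel {
public:
  MoveOnlyModel() = default;
  MoveOnlyModel(const MoveOnlyModel &) = delete;
  MoveOnlyModel &operator=(const MoveOnlyModel &) = delete;
  MoveOnlyModel(MoveOnlyModel &&) noexcept = default;
  MoveOnlyModel &operator=(MoveOnlyModel &&) noexcept = default;

  void convertToFloat() { *Module.Bits = 32; } // mutating operation
  int bits() const { return *Module.Bits; }

private:
  FakeModule Module;
};

static_assert(!std::is_copy_constructible<MoveOnlyModel>::value,
              "accidental copies are rejected at compile time");
static_assert(std::is_move_constructible<MoveOnlyModel>::value,
              "ownership can still be transferred");

// Shows the aliasing that a copyable wrapper would expose.
inline bool copiedHandlesAlias() {
  FakeModule A;
  FakeModule B = A; // B shares A's state
  *B.Bits = 32;     // a mutation through B ...
  return *A.Bits == 32; // ... is visible through A
}

// Ownership transfer leaves exactly one live owner.
inline int moveAndConvert() {
  MoveOnlyModel A;
  MoveOnlyModel B = std::move(A);
  assert(B.bits() == 64);
  B.convertToFloat();
  return B.bits();
}
```

With the copy operations deleted, a line such as `MoveOnlyModel B = A;` fails to compile, so the shared mutation shown in copiedHandlesAlias() cannot be written by accident.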
Subclassed by ams::ml::InferenceModel
Public Types
-
using DType = torch::ScalarType
Scalar type used by the model weights.
-
using DeviceType = torch::Device
Device type on which the model resides.
-
using HashT = uint64_t
An integer identifier for the model.
Public Functions
-
inline bool isDevice() const
Return true if the model is resident on a device (e.g. GPU).
This is typically equivalent to checking whether the underlying device is a CUDA or HIP device, as opposed to a CPU device.
-
inline torch::Device getDevice() const
Return the Torch device on which the model currently lives.
-
template<typename ScalarT>
inline bool isType() const
Return true if the model dtype matches the given scalar type.
Example:
if (Model.isType<double>()) { ... }
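A check of this shape can be backed by a compile-time mapping from the C++ scalar type to a dtype tag. The sketch below illustrates the idea with hypothetical stand-ins (FakeDType, DTypeOf, FakeModel). It is not the AMS implementation and does not use Torch.

```cpp
#include <cassert>

// Hypothetical stand-in for torch::ScalarType.
enum class FakeDType { Float, Double };

// Hypothetical trait mapping a C++ scalar type to a dtype tag.
template <typename ScalarT> struct DTypeOf;
template <> struct DTypeOf<float> {
  static constexpr FakeDType value = FakeDType::Float;
};
template <> struct DTypeOf<double> {
  static constexpr FakeDType value = FakeDType::Double;
};

// Hypothetical model that records only its dtype.
class FakeModel {
public:
  explicit FakeModel(FakeDType T) : Type(T) {}

  // True when the stored dtype matches the tag for ScalarT.
  template <typename ScalarT> bool isType() const {
    return Type == DTypeOf<ScalarT>::value;
  }

private:
  FakeDType Type;
};

inline bool isDoubleModel(const FakeModel &M) { return M.isType<double>(); }

inline void checkExample() {
  FakeModel M(FakeDType::Double);
  assert(M.isType<double>());
  assert(!M.isType<float>());
}
```

Because the primary template DTypeOf is declared but never defined, asking for an unmapped scalar type is a compile error rather than a silent false.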
-
template<typename ScalarT>
inline AMSStatus convertTo(std::optional<torch::Device> TargetDevice = std::nullopt)
Convert the model to the requested device and dtype.
This is a combined convenience operation that migrates the model to a target device and scalar type in one step. The function modifies the current object in place.
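The shape of such a combined, in-place conversion can be sketched as follows. All names here (Precision, PrecisionOf, DeviceName, SketchModel) are hypothetical stand-ins, and the sketch omits the AMSStatus return value. It also assumes that an empty TargetDevice leaves the current device unchanged, which the documentation above does not state.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-ins for torch::ScalarType and torch::Device.
enum class Precision { Single, Double };
using DeviceName = std::string;

// Hypothetical mapping from a C++ scalar type to a precision tag.
template <typename ScalarT> struct PrecisionOf;
template <> struct PrecisionOf<float> {
  static constexpr Precision value = Precision::Single;
};
template <> struct PrecisionOf<double> {
  static constexpr Precision value = Precision::Double;
};

struct SketchModel {
  DeviceName Device = "cpu";
  Precision Type = Precision::Double;

  // Assumption: an empty TargetDevice keeps the current device.
  template <typename ScalarT>
  void convertTo(std::optional<DeviceName> TargetDevice = std::nullopt) {
    if (TargetDevice)
      Device = *TargetDevice;
    Type = PrecisionOf<ScalarT>::value;
  }
};

inline void checkConvert() {
  SketchModel M;
  M.convertTo<float>(); // dtype changes, device stays "cpu"
  assert(M.Device == "cpu" && M.Type == Precision::Single);
  M.convertTo<double>(DeviceName("cuda:0")); // both change in one call
  assert(M.Device == "cuda:0" && M.Type == Precision::Double);
}
```

The object is mutated rather than returned as a new value, which is the property that makes copyable aliases of the real class hazardous.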
Public Static Functions
-
static AMSExpected<std::unique_ptr<BaseModel>> load(const AbstractModel &Descriptor)
Load a model from an AbstractModel descriptor using the on-disk dtype and device encoded in the model file.
On success, returns a fully constructed BaseModel instance. On error, returns an AMSError describing the failure.
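The value-or-error factory pattern used by load() can be illustrated with a small self-contained sketch. FakeExpected, FakeError, FakeModel, and loadModel are hypothetical stand-ins built on std::variant. The real AMSExpected and AMSError interfaces may differ, and no file is actually read here.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <variant>

// Hypothetical error type; the real AMSError interface may differ.
struct FakeError {
  std::string Message;
};

// Hypothetical value-or-error holder in the spirit of AMSExpected.
template <typename T> class FakeExpected {
public:
  FakeExpected(T Value) : Storage(std::move(Value)) {}
  FakeExpected(FakeError Err) : Storage(std::move(Err)) {}

  bool hasValue() const { return std::holds_alternative<T>(Storage); }
  T &value() { return std::get<T>(Storage); }
  const FakeError &error() const { return std::get<FakeError>(Storage); }

private:
  std::variant<T, FakeError> Storage;
};

struct FakeModel {
  std::string Path;
};

// Hypothetical factory: returns an owned model or a description of the failure.
inline FakeExpected<std::unique_ptr<FakeModel>>
loadModel(const std::string &Path) {
  if (Path.empty())
    return FakeError{"empty model path"};
  return std::make_unique<FakeModel>(FakeModel{Path});
}

inline void checkLoad() {
  auto Ok = loadModel("model.pt"); // hypothetical file name
  assert(Ok.hasValue());
  assert(Ok.value()->Path == "model.pt");

  auto Bad = loadModel("");
  assert(!Bad.hasValue());
  assert(Bad.error().Message == "empty model path");
}
```

Returning a std::unique_ptr inside the result fits the move-only design: the caller receives sole ownership of the model, and the error path never produces a partially constructed object.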
Protected Functions
-
BaseModel(const AbstractModel &AModel)
Construct a BaseModel from an AbstractModel descriptor.
This is primarily intended for use by factory functions (e.g. load()) and subclasses such as InferenceModel.
-
inline torch::jit::Module &getJITModel()
Mutable access to the underlying Torch module.
Intended for subclasses that need to configure the module (e.g. set eval mode, attach buffers, etc.).
-
inline const torch::jit::Module &getJITModel() const
Const access to the underlying Torch module.
-
inline void setDevice(torch::Device Device)
Set the device of the model. This does not move or copy the model to that device.