Class SurrogateModel

Class Documentation

class SurrogateModel

Public Functions

inline ~SurrogateModel()
std::tuple<torch::Tensor, torch::Tensor> _evaluate(torch::Tensor &inputs, const float threshold)
std::tuple<torch::Tensor, torch::Tensor> evaluate(ams::MutableArrayRef<at::Tensor> Inputs, const float threshold)
inline bool is_gpu() const
inline bool is_cpu() const
inline bool is_resource(ams::AMSResourceType rType) const
inline bool is_float() const
inline bool is_double() const
inline bool is_type(ams::AMSDType dType) const
std::tuple<ams::AMSResourceType, torch::DeviceType> convertModelResourceType(std::string &device)
std::tuple<ams::AMSDType, torch::Dtype> convertModelDataType(std::string &type)
std::tuple<ams::AMSResourceType, torch::DeviceType> getModelResourceType() const
std::tuple<ams::AMSDType, torch::Dtype> getModelDataType() const

Public Static Functions

static inline std::shared_ptr<SurrogateModel> getInstance(std::string &model_path)

Protected Functions

SurrogateModel(std::string &model_path)

Protected Static Attributes

static std::unordered_map<std::string, std::shared_ptr<SurrogateModel>> instances