Class PointwiseConcatTransform
Defined in File pointwise_layout_transform.hpp
Inheritance Relationships
Base Type¶
public ams::LayoutTransform (Class LayoutTransform)
Class Documentation
class PointwiseConcatTransform : public ams::LayoutTransform
Concatenates the Inputs and InOuts tensor bundles into a single 2-D matrix of shape [N, SUM(K_i)], where:
N = batch size (the outer dimension)
K_i = flattened size of tensor field i, i.e. the product of all its dimensions except the batch dimension
Supports:
✔ Scalar fields (shape [N], so K_i = 1)
✔ Multi-channel fields (shape [N, K])
✔ Arbitrary shapes [N, …], flattened to [N, K_i]
✔ Prediction-only models
✔ Uncertainty-aware models returning a (prediction, uncertainty) pair
Produces an IndexMap for both pack() and unpack().
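As a rough illustration of the column arithmetic above, the sketch below computes each K_i and assigns consecutive column ranges inside the [N, SUM(K_i)] matrix. It is plain C++17 with no LibTorch dependency; ColumnRange and buildColumnLayout are hypothetical names standing in for an IndexMap entry, not part of the AMS API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for one IndexMap entry: where a field's flattened
// columns live inside the [N, SUM(K_i)] matrix.
struct ColumnRange {
  std::string name;
  int64_t offset;  // first column of this field
  int64_t width;   // K_i = product of all non-batch dimensions
};

// K_i for a field of shape [N, d1, ..., dr]; a scalar field [N] has K_i = 1.
int64_t flattenedWidth(const std::vector<int64_t>& shape) {
  if (shape.empty()) throw std::invalid_argument("field must have a batch dimension");
  return std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Assigns consecutive column ranges to the fields in order and returns the
// total width SUM(K_i). All fields must share the same batch size N.
int64_t buildColumnLayout(
    const std::vector<std::pair<std::string, std::vector<int64_t>>>& fields,
    std::vector<ColumnRange>& layout) {
  layout.clear();
  int64_t offset = 0;
  for (const auto& [name, shape] : fields) {
    // The first iteration validates fields.front() before it is indexed.
    if (shape.empty() || shape[0] != fields.front().second[0])
      throw std::invalid_argument("batch dimension mismatch in field " + name);
    const int64_t width = flattenedWidth(shape);
    layout.push_back({name, offset, width});
    offset += width;
  }
  return offset;
}
```

For fields of shape [8], [8, 3] and [8, 3, 3], this yields an [8, 13] layout whose fields occupy columns [0, 1), [1, 4) and [4, 13).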
Public Functions
inline virtual const char *name() const noexcept override
inline virtual AMSExpected<IndexMap> pack(const TensorBundle &Inputs, const TensorBundle &InOuts, at::Tensor &ModelInput) override
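A minimal sketch of what pack() does conceptually, assuming row-major buffers and assuming that the Inputs columns come before the InOuts columns (the ordering is not stated above). Field, ColumnRange and packFields are hypothetical names; the real method works on at::Tensor and TensorBundle and reports the layout as an AMSExpected<IndexMap>.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins (not the AMS types): a field is a row-major float
// buffer of shape [N, d1, ..., dr]; a ColumnRange records where it lands.
struct Field {
  std::string name;
  std::vector<int64_t> shape;
  std::vector<float> data;
};

struct ColumnRange {
  std::string name;
  int64_t offset;
  int64_t width;
};

// Concatenates Inputs then InOuts (ordering assumed) into a row-major
// [N, SUM(K_i)] matrix and records each field's column range.
std::vector<float> packFields(const std::vector<Field>& inputs,
                              const std::vector<Field>& inouts,
                              int64_t& numRows, int64_t& numCols,
                              std::vector<ColumnRange>& layout) {
  std::vector<const Field*> all;
  for (const auto& f : inputs) all.push_back(&f);
  for (const auto& f : inouts) all.push_back(&f);
  if (all.empty()) throw std::invalid_argument("nothing to pack");

  layout.clear();
  numCols = 0;
  numRows = -1;
  for (const Field* f : all) {
    if (f->shape.empty()) throw std::invalid_argument("missing batch dim: " + f->name);
    if (numRows < 0) numRows = f->shape[0];
    if (f->shape[0] != numRows) throw std::invalid_argument("batch mismatch: " + f->name);
    int64_t width = 1;  // K_i = product of the non-batch dimensions
    for (std::size_t d = 1; d < f->shape.size(); ++d) width *= f->shape[d];
    if (static_cast<int64_t>(f->data.size()) != numRows * width)
      throw std::invalid_argument("data size does not match shape: " + f->name);
    layout.push_back({f->name, numCols, width});
    numCols += width;
  }

  std::vector<float> matrix(static_cast<std::size_t>(numRows * numCols));
  for (std::size_t i = 0; i < all.size(); ++i) {
    const ColumnRange& r = layout[i];
    for (int64_t n = 0; n < numRows; ++n) {
      // Sample n of a row-major [N, ...] buffer is K_i contiguous values.
      auto src = all[i]->data.begin() + n * r.width;
      std::copy(src, src + r.width, matrix.begin() + n * numCols + r.offset);
    }
  }
  return matrix;
}
```

Because each sample's values are already contiguous in a row-major [N, d1, ..., dr] buffer, flattening needs no reordering; packing reduces to one copy per field per row.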
inline virtual AMSStatus unpack(const torch::jit::IValue &ModelOutput, TensorBundle &Outs, TensorBundle &InOuts, std::optional<at::Tensor> &Uncertainties) override
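Conversely, unpack() has to split the model's [N, C] prediction back into per-field buffers and, for uncertainty-aware models, surface the second output through Uncertainties. The sketch below models the two output forms with std::variant instead of torch::jit::IValue; Matrix, ModelOutput, ColumnRange and unpackColumns are hypothetical names, and it does not attempt to show how the real method distributes columns between Outs and InOuts.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins (not the AMS types): a row-major [N, C] matrix and a
// model output that is either a prediction or a (prediction, uncertainty) pair.
using Matrix = std::vector<float>;
using ModelOutput = std::variant<Matrix, std::pair<Matrix, Matrix>>;

struct ColumnRange {
  std::string name;
  int64_t offset;
  int64_t width;
};

// Splits the prediction columns back into one row-major [N, K_i] buffer per
// field; forwards the uncertainty matrix only if the model produced one.
std::vector<Matrix> unpackColumns(const ModelOutput& output, int64_t numRows,
                                  int64_t numCols,
                                  const std::vector<ColumnRange>& layout,
                                  std::optional<Matrix>& uncertainties) {
  const Matrix* pred = nullptr;
  if (const auto* m = std::get_if<Matrix>(&output)) {
    pred = m;
    uncertainties.reset();  // prediction-only model
  } else {
    const auto& predAndUnc = std::get<std::pair<Matrix, Matrix>>(output);
    pred = &predAndUnc.first;
    uncertainties = predAndUnc.second;  // uncertainty-aware model
  }
  if (static_cast<int64_t>(pred->size()) != numRows * numCols)
    throw std::invalid_argument("prediction does not match [N, C]");

  std::vector<Matrix> fields;
  for (const ColumnRange& r : layout) {
    if (r.offset < 0 || r.offset + r.width > numCols)
      throw std::out_of_range("column range outside prediction: " + r.name);
    Matrix field(static_cast<std::size_t>(numRows * r.width));
    for (int64_t n = 0; n < numRows; ++n) {
      // Copy this field's K_i columns of row n into its own contiguous row.
      auto src = pred->begin() + n * numCols + r.offset;
      std::copy(src, src + r.width, field.begin() + n * r.width);
    }
    fields.push_back(std::move(field));
  }
  return fields;
}
```

Keeping the uncertainty in an optional out-parameter mirrors the signature above: callers of a prediction-only model simply see an empty Uncertainties.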