Class PointwiseConcatTransform

Inheritance Relationships

Base Type

  • public ams::LayoutTransform

Class Documentation

class PointwiseConcatTransform : public ams::LayoutTransform

PointwiseConcatTransform:

Concatenates the Inputs and InOuts tensor bundles into a single model-input matrix of shape [N, SUM(K_i)], where:

  • N = batch size (outer dim)

  • K_i = flattened size of field i, i.e. the product of all of its dimensions after the batch dimension (K_i = 1 for a scalar field)
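As a worked illustration of the width formula, the sketch below computes each K_i from a field's shape and sums them to get the packed width. It is a standalone approximation: shapes are plain std::vector<std::int64_t> values rather than at::Tensor sizes, and flattenedWidth and packedWidth are hypothetical helper names, not part of the AMS API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: a field's shape is {N, d1, d2, ...}.
// K_i is the product of every dimension after the batch dimension N;
// a scalar field of shape {N} has no trailing dims, so K_i = 1.
std::int64_t flattenedWidth(const std::vector<std::int64_t> &shape) {
  assert(!shape.empty() && "shape must include the batch dimension");
  return std::accumulate(shape.begin() + 1, shape.end(), std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// Hypothetical helper: total packed width SUM(K_i) over all fields.
std::int64_t packedWidth(const std::vector<std::vector<std::int64_t>> &shapes) {
  std::int64_t total = 0;
  for (const auto &s : shapes) total += flattenedWidth(s);
  return total;
}
```

For N = 8 and fields of shape [8], [8, 3] and [8, 2, 4], the widths are 1, 3 and 8, so the packed matrix would be [8, 12].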

Supports:

  • Scalar fields (shape [N])

  • Multi-channel fields (shape [N, K])

  • Arbitrary shapes [N, …], flattened to [N, M]

  • Prediction-only models

  • Uncertainty-aware models returning (pred, uncertainty)
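The last two items concern what the model returns. The sketch below shows one way to tell the two cases apart, assuming a hypothetical ModelOutput variant in place of torch::jit::IValue and std::vector<double> in place of at::Tensor. It mirrors the role of the std::optional Uncertainties parameter of unpack(), but it is an illustration, not that function's implementation.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for a model output tensor (row-major values).
using Matrix = std::vector<double>;

// Assumption: a prediction-only model yields one matrix, while an
// uncertainty-aware model yields a (prediction, uncertainty) pair.
using ModelOutput = std::variant<Matrix, std::pair<Matrix, Matrix>>;

// Returns the prediction; sets `uncertainties` only when the model
// supplied one, otherwise resets it to std::nullopt.
Matrix splitOutput(const ModelOutput &out,
                   std::optional<Matrix> &uncertainties) {
  if (const auto *pair = std::get_if<std::pair<Matrix, Matrix>>(&out)) {
    uncertainties = pair->second;
    return pair->first;
  }
  uncertainties.reset();
  return std::get<Matrix>(out);
}
```

Using an optional out-parameter lets one code path serve both model kinds: callers simply check whether an uncertainty was produced.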

Produces an IndexMap describing the field layout of the packed matrix, for use by both pack() and unpack().
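The following sketch illustrates how a column-range layout lets the packed matrix be split back into its fields. ColumnRange, buildLayout and extractField are hypothetical names: the members of the real IndexMap type are not documented on this page, and the sketch assumes a contiguous row-major [N, M] matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical layout record (not the AMS IndexMap type): a field
// occupies columns [begin, end) of the packed [N, M] matrix.
struct ColumnRange {
  std::string name;
  std::int64_t begin;
  std::int64_t end;
};

// Builds column ranges from the per-field widths K_i, in packing order.
std::vector<ColumnRange> buildLayout(const std::vector<std::string> &names,
                                     const std::vector<std::int64_t> &widths) {
  assert(names.size() == widths.size());
  std::vector<ColumnRange> layout;
  std::int64_t offset = 0;
  for (std::size_t i = 0; i < names.size(); ++i) {
    layout.push_back({names[i], offset, offset + widths[i]});
    offset += widths[i];
  }
  return layout;
}

// Copies one field's columns out of a row-major [n, cols] matrix,
// producing a row-major [n, end - begin] buffer for that field.
std::vector<double> extractField(const std::vector<double> &packed,
                                 std::int64_t n, std::int64_t cols,
                                 const ColumnRange &r) {
  std::vector<double> out;
  out.reserve(static_cast<std::size_t>(n * (r.end - r.begin)));
  for (std::int64_t row = 0; row < n; ++row)
    for (std::int64_t j = r.begin; j < r.end; ++j)
      out.push_back(packed[row * cols + j]);
  return out;
}
```

Because the ranges are computed once from the widths, the same layout can drive both the forward copy and the reverse split.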

Public Functions

inline virtual const char *name() const noexcept override
inline virtual AMSExpected<IndexMap> pack(const TensorBundle &Inputs, const TensorBundle &InOuts, at::Tensor &ModelInput) override
inline virtual AMSStatus unpack(const torch::jit::IValue &ModelOutput, TensorBundle &Outs, TensorBundle &InOuts, std::optional<at::Tensor> &Uncertainties) override
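A minimal sketch of the concatenation that pack() performs, written without LibTorch so that it stays self-contained. It assumes each field is a contiguous row-major buffer of N * K_i values (the hypothetical Field struct) and that Inputs precede InOuts in column order. This page confirms neither assumption, and packRowMajor is an illustrative name, not the library's code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for one tensor field: shape {N, ...} plus a
// contiguous row-major buffer holding N * K elements.
struct Field {
  std::vector<std::int64_t> shape;
  std::vector<double> data;
};

// Per-sample element count K: product of all dims after the batch dim.
static std::int64_t widthOf(const Field &f) {
  return std::accumulate(f.shape.begin() + 1, f.shape.end(), std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// Packs fields (assumed order: Inputs, then InOuts) into a row-major
// [N, SUM(K_i)] matrix; row n holds sample n of every field.
std::vector<double> packRowMajor(const std::vector<Field> &fields,
                                 std::int64_t &cols) {
  assert(!fields.empty());
  const std::int64_t n = fields.front().shape.front();
  cols = 0;
  for (const auto &f : fields) {
    assert(f.shape.front() == n && "all fields must share the batch size");
    assert(static_cast<std::int64_t>(f.data.size()) == n * widthOf(f));
    cols += widthOf(f);
  }
  std::vector<double> out(static_cast<std::size_t>(n * cols));
  std::int64_t colOffset = 0;
  for (const auto &f : fields) {
    const std::int64_t k = widthOf(f);
    for (std::int64_t row = 0; row < n; ++row)
      for (std::int64_t j = 0; j < k; ++j)
        out[row * cols + colOffset + j] = f.data[row * k + j];
    colOffset += k;
  }
  return out;
}
```

Packing a scalar field [1, 2] together with a [2, 2] field [[10, 11], [20, 21]] gives the 2 × 3 matrix [[1, 10, 11], [2, 20, 21]].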