TensorMCMC provides low-rank tensor regression for tensor predictors and scalar covariates. The package uses simple MCMC-inspired stochastic updates and includes fast C++ acceleration for coefficient updates and predictions. This vignette demonstrates how to fit a tensor regression model, make predictions, and evaluate performance using cross-validation.
# Install the development version of TensorMCMC from GitHub.
# devtools is installed first only when it is not already available.
if (!requireNamespace("devtools", quietly = TRUE)) {
  install.packages("devtools")
}
devtools::install_github("Ritwick2012/TensorMCMC")

# Attach the package: the calls below (tensor.reg(), predict_tensor_reg(),
# cv.tensor.reg()) are written without a namespace prefix.
library(TensorMCMC)
# Reproducible simulated data ----
set.seed(2026)

# Problem dimensions
n <- 100     # observations
p <- 7       # extent of the first tensor mode
d <- 5       # extent of the second tensor mode
pgamma <- 2  # scalar covariates (same name as stats::pgamma, but unrelated)

# Random draws are taken in a fixed order (x, then z, then y) so that the
# results below are reproducible under the seed above.

# Tensor predictor: an n x p x d array
x <- array(rnorm(n * p * d), dim = c(n, p, d))

# Scalar covariates: an n x pgamma matrix
z <- matrix(rnorm(n * pgamma), nrow = n, ncol = pgamma)

# Response: drawn independently of x and z, so there is no true signal to
# recover; this example only illustrates the workflow.
y <- rnorm(n)
## Fitting Tensor Regression
# Fit a rank-2 tensor regression with nsweep = 10 stochastic sweeps.
# Positional order for tensor.reg(): scalar covariates z, tensor x, response y.
fit <- tensor.reg(z, x, y, rank = 2, nsweep = 10)
print(fit)
#> $beta.store
#> , , 1, 1
#>
#> [,1] [,2]
#> [1,] 0.6175178 0.5061472
#> [2,] 0.6128395 0.4613790
#> [3,] 0.6784481 0.3833239
#> [4,] 0.6860400 0.3874586
#> [5,] 0.8250265 0.3895816
#> [6,] 0.8373070 0.4136012
#> [7,] 0.8680680 0.5201791
#> [8,] 0.8007524 0.4354310
#> [9,] 0.8576530 0.4482342
#> [10,] 0.8612398 0.4684577
#>
#> , , 2, 1
#>
#> [,1] [,2]
#> [1,] 0.3975854 -0.7112143
#> [2,] 0.4524064 -0.7612774
#> [3,] 0.4068241 -0.7869248
#> [4,] 0.4271915 -0.7140589
#> [5,] 0.4452346 -0.7760923
#> [6,] 0.4219443 -0.7431099
#> [7,] 0.4701565 -0.7364423
#> [8,] 0.3999119 -0.8309820
#> [9,] 0.4333260 -0.7534764
#> [10,] 0.4454967 -0.7035353
#>
#> , , 3, 1
#>
#> [,1] [,2]
#> [1,] 0.5241160 0.8549870
#> [2,] 0.5069353 0.8692748
#> [3,] 0.4307115 0.9216209
#> [4,] 0.4180394 0.9036876
#> [5,] 0.3715078 0.9167656
#> [6,] 0.3466591 0.9865087
#> [7,] 0.3678846 0.9376166
#> [8,] 0.3445136 1.0112642
#> [9,] 0.2296820 1.0526901
#> [10,] 0.2664259 1.0176780
#>
#> , , 4, 1
#>
#> [,1] [,2]
#> [1,] -0.6825909 1.336783
#> [2,] -0.7030203 1.375442
#> [3,] -0.6352342 1.445072
#> [4,] -0.6030446 1.402250
#> [5,] -0.5922969 1.408574
#> [6,] -0.5767963 1.340538
#> [7,] -0.5842332 1.355406
#> [8,] -0.4912046 1.349345
#> [9,] -0.5511322 1.348177
#> [10,] -0.6217880 1.410002
#>
#> , , 5, 1
#>
#> [,1] [,2]
#> [1,] 0.4411591 -0.5538530
#> [2,] 0.4793234 -0.5697876
#> [3,] 0.3756414 -0.6338545
#> [4,] 0.4062431 -0.6303656
#> [5,] 0.3650854 -0.6224204
#> [6,] 0.4757318 -0.6048942
#> [7,] 0.5650123 -0.5794329
#> [8,] 0.5779342 -0.5647220
#> [9,] 0.5967613 -0.5749377
#> [10,] 0.5641120 -0.6093751
#>
#> , , 6, 1
#>
#> [,1] [,2]
#> [1,] -0.40522550 -0.0381084804
#> [2,] -0.45783346 0.0001840742
#> [3,] -0.44177185 -0.0756921244
#> [4,] -0.45725632 -0.0412047830
#> [5,] -0.36029857 -0.0352665653
#> [6,] -0.30226926 -0.1565793410
#> [7,] -0.19623602 -0.1824652891
#> [8,] -0.12042432 -0.2572382854
#> [9,] -0.07943167 -0.2385513959
#> [10,] -0.04975199 -0.2603805731
#>
#> , , 7, 1
#>
#> [,1] [,2]
#> [1,] 0.8347532 0.4227896
#> [2,] 0.8049088 0.3560675
#> [3,] 0.8064333 0.2958949
#> [4,] 0.7883311 0.3060742
#> [5,] 0.8306198 0.3427905
#> [6,] 0.7546770 0.3127194
#> [7,] 0.6590266 0.3218253
#> [8,] 0.6624119 0.3380433
#> [9,] 0.6484295 0.3104538
#> [10,] 0.5794229 0.2957562
#>
#> , , 1, 2
#>
#> [,1] [,2]
#> [1,] 1.560708 0.5129428
#> [2,] 1.668662 0.4854308
#> [3,] 1.628236 0.5120526
#> [4,] 1.605402 0.5505858
#> [5,] 1.577144 0.5427390
#> [6,] 1.508130 0.5617251
#> [7,] 1.572179 0.6077413
#> [8,] 1.541252 0.5855361
#> [9,] 1.540316 0.6206302
#> [10,] 1.630110 0.5329874
#>
#> , , 2, 2
#>
#> [,1] [,2]
#> [1,] -1.371256 -0.8709372
#> [2,] -1.380835 -0.8193080
#> [3,] -1.314125 -0.7908527
#> [4,] -1.309169 -0.7440521
#> [5,] -1.198581 -0.6887966
#> [6,] -1.195897 -0.6477300
#> [7,] -1.153860 -0.6271236
#> [8,] -1.141386 -0.6390028
#> [9,] -1.214423 -0.6491257
#> [10,] -1.174451 -0.5826772
#>
#> , , 3, 2
#>
#> [,1] [,2]
#> [1,] 1.554911 0.5309417
#> [2,] 1.463805 0.5149889
#> [3,] 1.423261 0.5620356
#> [4,] 1.462391 0.5395002
#> [5,] 1.524665 0.5398589
#> [6,] 1.569172 0.4644514
#> [7,] 1.491634 0.4498306
#> [8,] 1.497423 0.3831688
#> [9,] 1.508404 0.3515085
#> [10,] 1.534362 0.2890586
#>
#> , , 4, 2
#>
#> [,1] [,2]
#> [1,] 1.451390 0.4170323
#> [2,] 1.498053 0.3984593
#> [3,] 1.577584 0.4489993
#> [4,] 1.547421 0.4950810
#> [5,] 1.492551 0.4945987
#> [6,] 1.625103 0.4618810
#> [7,] 1.611064 0.5398709
#> [8,] 1.615667 0.5461319
#> [9,] 1.713706 0.5830314
#> [10,] 1.711902 0.6266945
#>
#> , , 5, 2
#>
#> [,1] [,2]
#> [1,] -0.6664497 -0.5522290
#> [2,] -0.6192354 -0.5776578
#> [3,] -0.6926324 -0.5648589
#> [4,] -0.7285844 -0.5799299
#> [5,] -0.6795743 -0.5134383
#> [6,] -0.6604034 -0.4695106
#> [7,] -0.6775142 -0.4373082
#> [8,] -0.6564486 -0.4097951
#> [9,] -0.6248954 -0.4860086
#> [10,] -0.6995958 -0.4458456
#>
#> , , 6, 2
#>
#> [,1] [,2]
#> [1,] -0.18560175 -1.1265199
#> [2,] -0.26733785 -1.0840325
#> [3,] -0.22326306 -1.0042501
#> [4,] -0.25946027 -0.9861894
#> [5,] -0.26324539 -1.0297807
#> [6,] -0.23786733 -0.9352913
#> [7,] -0.16806827 -0.9967152
#> [8,] -0.16920847 -1.0356691
#> [9,] -0.12999809 -1.0050305
#> [10,] -0.02949333 -1.0121170
#>
#> , , 7, 2
#>
#> [,1] [,2]
#> [1,] 0.9844980 -0.8321533
#> [2,] 1.0141669 -0.8183583
#> [3,] 0.9625310 -0.8044637
#> [4,] 0.9562452 -0.7783111
#> [5,] 0.9928112 -0.8703799
#> [6,] 0.9917019 -0.9137084
#> [7,] 0.9537772 -0.9119563
#> [8,] 0.9120949 -0.9096190
#> [9,] 0.8514157 -0.9760316
#> [10,] 0.7412234 -0.9850303
#>
#> , , 1, 3
#>
#> [,1] [,2]
#> [1,] 1.189347 1.549281
#> [2,] 1.233500 1.583087
#> [3,] 1.279307 1.609604
#> [4,] 1.352898 1.544066
#> [5,] 1.276144 1.534223
#> [6,] 1.313418 1.480760
#> [7,] 1.373150 1.455964
#> [8,] 1.429736 1.401691
#> [9,] 1.469976 1.325374
#> [10,] 1.479527 1.312298
#>
#> , , 2, 3
#>
#> [,1] [,2]
#> [1,] 0.5858142 -0.5659815
#> [2,] 0.5426417 -0.6136374
#> [3,] 0.5209279 -0.5885290
#> [4,] 0.4476204 -0.5598592
#> [5,] 0.4659866 -0.5706244
#> [6,] 0.4549289 -0.5636388
#> [7,] 0.5260105 -0.5840377
#> [8,] 0.4709779 -0.5852145
#> [9,] 0.4180531 -0.5880248
#> [10,] 0.4179754 -0.6452659
#>
#> , , 3, 3
#>
#> [,1] [,2]
#> [1,] -2.766364 -0.04904548
#> [2,] -2.775928 -0.03338237
#> [3,] -2.745984 -0.02795111
#> [4,] -2.770318 -0.03173546
#> [5,] -2.786845 -0.08899732
#> [6,] -2.794924 -0.07856328
#> [7,] -2.690536 -0.04555793
#> [8,] -2.716557 -0.06367433
#> [9,] -2.750640 -0.04160384
#> [10,] -2.831449 -0.04843835
#>
#> , , 4, 3
#>
#> [,1] [,2]
#> [1,] 0.6498014 0.9280571
#> [2,] 0.5370901 0.8398339
#> [3,] 0.5836968 0.8274056
#> [4,] 0.6090371 0.8108566
#> [5,] 0.5521506 0.8252723
#> [6,] 0.5145879 0.8631900
#> [7,] 0.5447565 0.9091312
#> [8,] 0.5402092 0.8735775
#> [9,] 0.5418392 0.9116803
#> [10,] 0.5050690 0.8950233
#>
#> , , 5, 3
#>
#> [,1] [,2]
#> [1,] 1.285322 1.361976
#> [2,] 1.293487 1.438172
#> [3,] 1.298185 1.435651
#> [4,] 1.248741 1.436864
#> [5,] 1.148789 1.503019
#> [6,] 1.158148 1.532379
#> [7,] 1.118180 1.547905
#> [8,] 1.194805 1.571523
#> [9,] 1.187147 1.621224
#> [10,] 1.123620 1.694466
#>
#> , , 6, 3
#>
#> [,1] [,2]
#> [1,] 1.261212 0.3738436
#> [2,] 1.242313 0.4179689
#> [3,] 1.238785 0.4143482
#> [4,] 1.147486 0.4748957
#> [5,] 1.166239 0.4922945
#> [6,] 1.230612 0.4981113
#> [7,] 1.191155 0.4746349
#> [8,] 1.159832 0.4631773
#> [9,] 1.149133 0.5108106
#> [10,] 1.153399 0.4808392
#>
#> , , 7, 3
#>
#> [,1] [,2]
#> [1,] 1.263711 -0.9169191
#> [2,] 1.199462 -0.9075602
#> [3,] 1.150362 -0.9562636
#> [4,] 1.026041 -0.9373388
#> [5,] 1.032376 -0.9625337
#> [6,] 1.051735 -0.9289627
#> [7,] 1.103608 -0.9186450
#> [8,] 1.092824 -0.8641071
#> [9,] 1.205468 -0.8943685
#> [10,] 1.202552 -0.8641974
#>
#> , , 1, 4
#>
#> [,1] [,2]
#> [1,] 0.11128971 0.5521929
#> [2,] 0.06903301 0.5955725
#> [3,] 0.09899319 0.6272420
#> [4,] 0.11564962 0.6084482
#> [5,] 0.08302681 0.5391810
#> [6,] 0.09710597 0.5568605
#> [7,] 0.06141057 0.5929310
#> [8,] 0.04531466 0.5931448
#> [9,] 0.07301844 0.6005005
#> [10,] 0.03915599 0.5706949
#>
#> , , 2, 4
#>
#> [,1] [,2]
#> [1,] 0.5927077 -0.2483630
#> [2,] 0.6179302 -0.2762910
#> [3,] 0.6239297 -0.3384486
#> [4,] 0.6626618 -0.2964076
#> [5,] 0.5362938 -0.2873118
#> [6,] 0.6663809 -0.2678621
#> [7,] 0.7025902 -0.2678690
#> [8,] 0.7655182 -0.3493141
#> [9,] 0.7963968 -0.2955797
#> [10,] 0.7981744 -0.2723387
#>
#> , , 3, 4
#>
#> [,1] [,2]
#> [1,] 0.8916531 0.2334595151
#> [2,] 0.8831329 0.2187412378
#> [3,] 0.8420721 0.1200824029
#> [4,] 0.8721117 0.1779409522
#> [5,] 0.8718654 0.1672570386
#> [6,] 0.9255689 0.1893203788
#> [7,] 0.8921269 0.0873405673
#> [8,] 0.9666169 0.1021759850
#> [9,] 0.9007138 0.0910585078
#> [10,] 0.9593646 -0.0009226787
#>
#> , , 4, 4
#>
#> [,1] [,2]
#> [1,] -0.5432683 1.850503
#> [2,] -0.5458781 1.849160
#> [3,] -0.5302939 1.863180
#> [4,] -0.5683311 1.864004
#> [5,] -0.6239891 1.868322
#> [6,] -0.6102463 1.745608
#> [7,] -0.5703773 1.844175
#> [8,] -0.5446780 1.881386
#> [9,] -0.4587596 1.854574
#> [10,] -0.4968003 1.868080
#>
#> , , 5, 4
#>
#> [,1] [,2]
#> [1,] -0.4174239 -0.2776256
#> [2,] -0.3892290 -0.2558838
#> [3,] -0.3362368 -0.2532731
#> [4,] -0.3661752 -0.2472783
#> [5,] -0.4074676 -0.2387563
#> [6,] -0.3942243 -0.2501432
#> [7,] -0.4335010 -0.2834402
#> [8,] -0.3926171 -0.2294251
#> [9,] -0.4721143 -0.1550824
#> [10,] -0.4720577 -0.1376031
#>
#> , , 6, 4
#>
#> [,1] [,2]
#> [1,] 0.8435491 -0.09216073
#> [2,] 0.7978925 -0.05712930
#> [3,] 0.7994098 -0.17193085
#> [4,] 0.7078550 -0.22253593
#> [5,] 0.6871246 -0.12425080
#> [6,] 0.6730658 -0.19033237
#> [7,] 0.6829471 -0.27186800
#> [8,] 0.6247745 -0.26500624
#> [9,] 0.6533393 -0.32546826
#> [10,] 0.6502446 -0.18738662
#>
#> , , 7, 4
#>
#> [,1] [,2]
#> [1,] -0.6408948 2.114525
#> [2,] -0.5912416 2.182516
#> [3,] -0.6068960 2.146908
#> [4,] -0.6128581 2.140197
#> [5,] -0.5326529 2.120325
#> [6,] -0.4158944 2.119690
#> [7,] -0.3179242 2.136305
#> [8,] -0.4147551 2.140845
#> [9,] -0.4547056 2.094674
#> [10,] -0.5094077 2.179198
#>
#> , , 1, 5
#>
#> [,1] [,2]
#> [1,] 0.7062451 -0.7723304
#> [2,] 0.6196864 -0.7536531
#> [3,] 0.5829163 -0.7512756
#> [4,] 0.5723702 -0.7895989
#> [5,] 0.5480475 -0.7942639
#> [6,] 0.5745074 -0.7906499
#> [7,] 0.6316668 -0.8471935
#> [8,] 0.5977902 -0.8662415
#> [9,] 0.6093851 -0.7695625
#> [10,] 0.5319518 -0.8286601
#>
#> , , 2, 5
#>
#> [,1] [,2]
#> [1,] -0.4647624 0.3354471
#> [2,] -0.5554142 0.3508058
#> [3,] -0.5462820 0.3627828
#> [4,] -0.5099343 0.4127992
#> [5,] -0.5889191 0.5201830
#> [6,] -0.5507445 0.5051879
#> [7,] -0.4826550 0.4411052
#> [8,] -0.4517509 0.4181059
#> [9,] -0.4697075 0.4184517
#> [10,] -0.4203216 0.4538697
#>
#> , , 3, 5
#>
#> [,1] [,2]
#> [1,] 0.3044868 -0.5891886
#> [2,] 0.3196015 -0.4866501
#> [3,] 0.3187785 -0.4407762
#> [4,] 0.3168405 -0.4297762
#> [5,] 0.2795286 -0.3597386
#> [6,] 0.3343109 -0.2855225
#> [7,] 0.3202693 -0.3046614
#> [8,] 0.3974828 -0.2362633
#> [9,] 0.3357710 -0.1939219
#> [10,] 0.3446637 -0.1584104
#>
#> , , 4, 5
#>
#> [,1] [,2]
#> [1,] 0.3305792 -1.758757
#> [2,] 0.3219762 -1.750041
#> [3,] 0.3565360 -1.783679
#> [4,] 0.2974778 -1.804834
#> [5,] 0.2718297 -1.905340
#> [6,] 0.3097949 -1.909156
#> [7,] 0.2853263 -1.929792
#> [8,] 0.2824352 -2.022980
#> [9,] 0.2556396 -2.023499
#> [10,] 0.3017312 -1.959575
#>
#> , , 5, 5
#>
#> [,1] [,2]
#> [1,] -0.2633088 -0.9496400
#> [2,] -0.2476802 -0.9167193
#> [3,] -0.2538832 -0.9065747
#> [4,] -0.2543031 -0.9318421
#> [5,] -0.3509086 -0.8779900
#> [6,] -0.3788831 -0.8625721
#> [7,] -0.4523564 -0.9500570
#> [8,] -0.3936900 -1.0716669
#> [9,] -0.3521307 -1.0875145
#> [10,] -0.3686755 -1.0629621
#>
#> , , 6, 5
#>
#> [,1] [,2]
#> [1,] -0.4835981 0.4442324
#> [2,] -0.3927571 0.4616747
#> [3,] -0.3096913 0.3422053
#> [4,] -0.3274816 0.2984537
#> [5,] -0.3681045 0.3694834
#> [6,] -0.3249294 0.4494309
#> [7,] -0.3099106 0.4368485
#> [8,] -0.3142109 0.4563319
#> [9,] -0.2969563 0.5129710
#> [10,] -0.3251497 0.5521248
#>
#> , , 7, 5
#>
#> [,1] [,2]
#> [1,] 1.614183 0.3410591
#> [2,] 1.568054 0.3156264
#> [3,] 1.608146 0.2918482
#> [4,] 1.528731 0.3051028
#> [5,] 1.498396 0.3108351
#> [6,] 1.439063 0.2737441
#> [7,] 1.491165 0.2446965
#> [8,] 1.489875 0.2994359
#> [9,] 1.414947 0.2267824
#> [10,] 1.395807 0.2377329
#>
#>
#> $gam.store
#> [,1] [,2]
#> [1,] 0.009337902 -0.003779198
#> [2,] 0.005613655 -0.028679387
#> [3,] 0.011479618 -0.037633718
#> [4,] 0.010915989 -0.036573155
#> [5,] 0.020891768 -0.037841989
#> [6,] 0.019939775 -0.038617585
#> [7,] 0.006675089 -0.034438883
#> [8,] 0.006593578 -0.038430838
#> [9,] 0.010870810 -0.046452728
#> [10,] -0.005625362 -0.058371163
#>
#> $rank
#> [1] 2
#>
#> $p
#> [1] 7
#>
#> $d
#> [1] 5
#>
#> $my
#> [1] 0.09324649
#>
#> $sy
#> [1] 1.106799
#>
#> $mx
#> [1] 2.619622e-18 6.035652e-19 -7.708677e-19 -1.376335e-18 2.137234e-18
#> [6] -1.450337e-18 2.499709e-18
#>
#> $sx
#> [1] 0.2118479 0.2004634 0.1916164 0.2131522 0.2018289 0.1968717 0.1998765
#>
#> attr(,"class")
#> [1] "tensor.reg"
## Predictions
# In-sample predictions on the same data used for fitting.
# Positional order for predict_tensor_reg(): fitted model, tensor x, covariates z.
pred <- predict_tensor_reg(fit, x, z)
head(pred, n = 6L)
#> [1] -1.07748772 0.63726048 0.09267041 2.04084337 -1.87800737 0.09874663
## Cross-Validation
# Compare candidate ranks 1 and 2 by cross-validated RMSE.
# A small nsweep keeps the vignette fast to build.
# Positional order for cv.tensor.reg(): tensor x, covariates z, response y.
cv <- cv.tensor.reg(x, z, y, nsweep = 5, ranks = seq_len(2))
print(cv)
#> rank RMSE
#> 1 1 1.839541
#> 2 2 2.277003
## Scatter plot of predicted vs actual
# Predicted vs observed response. Points on the dashed red identity line
# would be perfect in-sample predictions.
plot(
  y, pred,
  main = "Predicted vs Actual Response",
  xlab = "Actual y", ylab = "Predicted y",
  pch = 19, col = "blue"
)
abline(a = 0, b = 1, lty = 2, col = "red")

# Entry (1, 1) of the tensor predictor for every observation.
x1 <- x[, 1, 1]
## Scatter plot of Predicted vs a Single Tensor Entry (x[, 1, 1])
# Predicted response against the single tensor entry x1.
plot(
  x1, pred,
  main = "Predicted vs Tensor Covariate",
  xlab = "Tensor Covariate", ylab = "Predicted y",
  pch = 19, col = "purple"
)
# Overlay the least-squares line of pred on x1 (dashed, green).
abline(lm(pred ~ x1), lty = 2, col = "green")