-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit.html
1360 lines (1053 loc) · 105 KB
/
init.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="robots" content="noindex">
<title>torch.nn.init — PyTorch 2.0 documentation</title>
<link rel="canonical" href="https://pytorch.org/docs/stable/_modules/torch/nn/init.html">
<!-- Stylesheet order matters: theme.css is linked after pygments.css so that, at equal
     specificity, the theme's rules take precedence over the syntax-highlighting rules. -->
<link rel="stylesheet" href="../../../_static/pygments.css">
<link rel="stylesheet" href="../../../_static/css/theme.css">
<link rel="stylesheet" href="../../../_static/copybutton.css">
<!-- NOTE(review): the "katex@VERSION" path segment in every KaTeX URL below appears to have
     been replaced by an e-mail-obfuscation placeholder ("[email protected]") when this page
     was captured, so these requests cannot resolve. Restore the real versions from the
     original build. The two stylesheets may reference different KaTeX releases, so both
     are kept rather than deduplicated. -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css">
<link rel="stylesheet" href="../../../_static/katex-math.css">
<link rel="stylesheet" href="../../../_static/sphinx-dropdown.css">
<link rel="stylesheet" href="../../../_static/panels-bootstrap.min.css">
<link rel="stylesheet" href="../../../_static/css/jit.css">
<link rel="index" title="Index" href="../../../genindex.html">
<link rel="search" title="Search" href="../../../search.html">
<!-- Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-117752657-2"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'UA-117752657-2');
</script>
<!-- End Google Analytics -->
<script src="../../../_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="../../../_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../../../_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../../../_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../../../_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../../../_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../../../_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
</head>
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li>
<a href="https://pytorch.org/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/ecosystem">Ecosystem</a>
</li>
<li>
<a href="https://pytorch.org/mobile">Mobile</a>
</li>
<li>
<a href="https://pytorch.org/blog/">Blog</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li class="active docs-active">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-orange-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/docs/stable/index.html">
<span class="dropdown-title">PyTorch</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/audio/stable/index.html">
<span class="dropdown-title">torchaudio</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/text/stable/index.html">
<span class="dropdown-title">torchtext</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/vision/stable/index.html">
<span class="dropdown-title">torchvision</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torcharrow">
<span class="dropdown-title">torcharrow</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/data">
<span class="dropdown-title">TorchData</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchrec">
<span class="dropdown-title">TorchRec</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/serve/">
<span class="dropdown-title">TorchServe</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchx/">
<span class="dropdown-title">TorchX</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/xla">
<span class="dropdown-title">PyTorch on XLA Devices</span>
<p></p>
</a>
</div>
<!-- Closes div.resources-dropdown; this end tag was missing, leaving the element
     to be closed implicitly by the parser at </li>. -->
</div>
</li>
<li>
<!-- NOTE(review): this id repeats the one on the Docs dropdown above. Ids must be unique,
     but it is left unchanged in case theme CSS/JS selects on it; confirm before renaming. -->
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-arrow">
Resources
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/features">
<span class="dropdown-title">About</span>
<p>Learn about PyTorch’s features and capabilities</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/foundation">
<span class="dropdown-title">PyTorch Foundation</span>
<p>Learn about the PyTorch foundation</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/#community-module">
<span class="dropdown-title">Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch.</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/resources">
<span class="dropdown-title">Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/events">
<span class="dropdown-title">Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org/" target="_blank" rel="noopener">
<span class="dropdown-title">Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/hub">
<span class="dropdown-title">Models (Beta)</span>
<p>Discover, publish, and reuse pre-trained models</p>
</a>
</div>
</div>
</li>
<li>
<a href="https://github.com/pytorch/pytorch">GitHub</a>
</li>
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu" aria-label="Open menu"></a>
</div>
</div>
</div>
<body class="pytorch-body">
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<!-- The toggle link has no text content, so aria-label supplies its accessible name. -->
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents" aria-label="Toggle table of contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
<a href="https://pytorch.org/docs/versions.html">2.0 ▼</a>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
<!-- The placeholder disappears on input and is not a reliable label; aria-label names the field. -->
<input type="text" name="q" placeholder="Search Docs" aria-label="Search Docs">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
</div>
</div>
<p class="caption" role="heading"><span class="caption-text">Community</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../community/build_ci_governance.html">PyTorch Governance | Build + CI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../community/contribution_guide.html">PyTorch Contribution Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../community/design.html">PyTorch Design Philosophy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../community/governance.html">PyTorch Governance | Mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../community/persons_of_interest.html">PyTorch Governance | Maintainers</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Notes</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/amp_examples.html">CUDA Automatic Mixed Precision examples</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/autograd.html">Autograd mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/broadcasting.html">Broadcasting semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/cpu_threading_torchscript_inference.html">CPU threading and TorchScript inference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/cuda.html">CUDA semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/ddp.html">Distributed Data Parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/extending.html">Extending PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/extending.func.html">Extending torch.func with autograd.Function</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/faq.html">Frequently Asked Questions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/gradcheck.html">Gradcheck mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/hip.html">HIP (ROCm) semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/large_scale_deployments.html">Features for large-scale deployments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/modules.html">Modules</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/mps.html">MPS backend</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/multiprocessing.html">Multiprocessing best practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/numerical_accuracy.html">Numerical accuracy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/randomness.html">Reproducibility</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/serialization.html">Serialization semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../notes/windows.html">Windows FAQ</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">torch.compile</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/index.html">TorchDynamo Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/installation.html">Installing TorchDynamo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/get-started.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/guards-overview.html">Guards Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/custom-backends.html">Custom Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/deep-dive.html">TorchDynamo Deeper Dive</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/troubleshooting.html">TorchDynamo Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dynamo/faq.html">Frequently Asked Questions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../ir.html">IRs</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Language Bindings</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../cpp_index.html">C++</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/javadoc/">Javadoc</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../deploy.html">torch::deploy</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../../torch.html">torch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../nn.html">torch.nn</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../nn.functional.html">torch.nn.functional</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../tensors.html">torch.Tensor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../tensor_attributes.html">Tensor Attributes</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../tensor_view.html">Tensor Views</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../amp.html">torch.amp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../autograd.html">torch.autograd</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../library.html">torch.library</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../cuda.html">torch.cuda</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../mps.html">torch.mps</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../backends.html">torch.backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.html">torch.distributed</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.algorithms.join.html">torch.distributed.algorithms.join</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.elastic.html">torch.distributed.elastic</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../fsdp.html">torch.distributed.fsdp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.optim.html">torch.distributed.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.tensor.parallel.html">torch.distributed.tensor.parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributed.checkpoint.html">torch.distributed.checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../distributions.html">torch.distributions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../_dynamo.html">torch._dynamo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../fft.html">torch.fft</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../func.html">torch.func</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../futures.html">torch.futures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../fx.html">torch.fx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../hub.html">torch.hub</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../jit.html">torch.jit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../linalg.html">torch.linalg</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../monitor.html">torch.monitor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../signal.html">torch.signal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../special.html">torch.special</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../torch.overrides.html">torch.overrides</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../package.html">torch.package</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../profiler.html">torch.profiler</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../nn.init.html">torch.nn.init</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../onnx.html">torch.onnx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../onnx_diagnostics.html">torch.onnx diagnostics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../optim.html">torch.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../complex_numbers.html">Complex Numbers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../ddp_comm_hooks.html">DDP Communication Hooks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../pipeline.html">Pipeline Parallelism</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../rpc.html">Distributed RPC Framework</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../random.html">torch.random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../masked.html">torch.masked</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../nested.html">torch.nested</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../sparse.html">torch.sparse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../storage.html">torch.Storage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../testing.html">torch.testing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../benchmark_utils.html">torch.utils.benchmark</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../bottleneck.html">torch.utils.bottleneck</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../checkpoint.html">torch.utils.checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../cpp_extension.html">torch.utils.cpp_extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../data.html">torch.utils.data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../jit_utils.html">torch.utils.jit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../dlpack.html">torch.utils.dlpack</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../mobile_optimizer.html">torch.utils.mobile_optimizer</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../model_zoo.html">torch.utils.model_zoo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../tensorboard.html">torch.utils.tensorboard</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../type_info.html">Type Info</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../named_tensor.html">Named Tensors</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../name_inference.html">Named Tensors operator coverage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../config_mod.html">torch.__config__</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Libraries</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/audio/stable">torchaudio</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/data">TorchData</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/torchrec">TorchRec</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/serve">TorchServe</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/text/stable">torchtext</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/vision/stable">torchvision</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/xla/">PyTorch on XLA Devices</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="../../../index.html">
Docs
</a> >
</li>
<li><a href="../../index.html">Module code</a> ></li>
<li><a href="../../torch.html">torch</a> ></li>
<!-- Last crumb is the page being viewed; aria-current exposes that to assistive technology. -->
<li aria-current="page">torch.nn.init</li>
<li class="pytorch-breadcrumbs-aside">
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<h1>Source code for torch.nn.init</h1><div class="highlight"><pre>
<span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="c1"># These no_grad_* functions are necessary as wrappers around the parts of these</span>
<span class="c1"># functions that use `with torch.no_grad()`. The JIT doesn't support context</span>
<span class="c1"># managers, so these need to be implemented as builtins. Using these wrappers</span>
<span class="c1"># lets us keep those builtins small and re-usable.</span>
<span class="k">def</span> <span class="nf">_no_grad_uniform_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span><!-- leading spaces on the next lines are significant inside <pre>: they render the Python block structure -->
    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_no_grad_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span><!-- leading spaces on the next lines are significant inside <pre>: they render the Python block structure -->
    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_no_grad_trunc_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="c1"># Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf</span>
<span class="k">def</span> <span class="nf">norm_cdf</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="c1"># Computes standard normal cumulative distribution function</span>
<span class="k">return</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">erf</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.</span><span class="p">)))</span> <span class="o">/</span> <span class="mf">2.</span>
<span class="k">if</span> <span class="p">(</span><span class="n">mean</span> <span class="o"><</span> <span class="n">a</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">std</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">mean</span> <span class="o">></span> <span class="n">b</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">std</span><span class="p">):</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "</span>
<span class="s2">"The distribution of values may be incorrect."</span><span class="p">,</span>
<span class="n">stacklevel</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="c1"># Values are generated by using a truncated uniform distribution and</span>
<span class="c1"># then using the inverse CDF for the normal distribution.</span>
<span class="c1"># Get upper and lower cdf values</span>
<span class="n">l</span> <span class="o">=</span> <span class="n">norm_cdf</span><span class="p">((</span><span class="n">a</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">std</span><span class="p">)</span>
<span class="n">u</span> <span class="o">=</span> <span class="n">norm_cdf</span><span class="p">((</span><span class="n">b</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">std</span><span class="p">)</span>
<span class="c1"># Uniformly fill tensor with values from [l, u], then translate to</span>
<span class="c1"># [2l-1, 2u-1].</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">l</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">u</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># Use inverse cdf transform for normal distribution to get truncated</span>
<span class="c1"># standard normal</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">erfinv_</span><span class="p">()</span>
<span class="c1"># Transform to proper mean, std</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">std</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.</span><span class="p">))</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span>
<span class="c1"># Clamp to ensure it's in the proper range</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">b</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span>
<span class="k">def</span> <span class="nf">_no_grad_fill_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">val</span><span class="p">):</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">fill_</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_no_grad_zero_</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span>
<div class="viewcode-block" id="calculate_gain"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.calculate_gain">[docs]</a><span class="k">def</span> <span class="nf">calculate_gain</span><span class="p">(</span><span class="n">nonlinearity</span><span class="p">,</span> <span class="n">param</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""Return the recommended gain value for the given nonlinearity function.</span>
<span class="sd"> The values are as follows:</span>
<span class="sd"> ================= ====================================================</span>
<span class="sd"> nonlinearity gain</span>
<span class="sd"> ================= ====================================================</span>
<span class="sd"> Linear / Identity :math:`1`</span>
<span class="sd"> Conv{1,2,3}D :math:`1`</span>
<span class="sd"> Sigmoid :math:`1`</span>
<span class="sd"> Tanh :math:`\frac{5}{3}`</span>
<span class="sd"> ReLU :math:`\sqrt{2}`</span>
<span class="sd"> Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`</span>
<span class="sd"> SELU :math:`\frac{3}{4}`</span>
<span class="sd"> ================= ====================================================</span>
<span class="sd"> .. warning::</span>
<span class="sd"> In order to implement `Self-Normalizing Neural Networks`_ ,</span>
<span class="sd"> you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.</span>
<span class="sd"> This gives the initial weights a variance of ``1 / N``,</span>
<span class="sd"> which is necessary to induce a stable fixed point in the forward pass.</span>
<span class="sd"> In contrast, the default gain for ``SELU`` sacrifices the normalisation</span>
<span class="sd"> effect for more stable gradient flow in rectangular layers.</span>
<span class="sd"> Args:</span>
<span class="sd"> nonlinearity: the non-linear function (`nn.functional` name)</span>
<span class="sd"> param: optional parameter for the non-linear function</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2</span>
<span class="sd"> .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html</span>
<span class="sd"> """</span>
<span class="n">linear_fns</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'conv1d'</span><span class="p">,</span> <span class="s1">'conv2d'</span><span class="p">,</span> <span class="s1">'conv3d'</span><span class="p">,</span> <span class="s1">'conv_transpose1d'</span><span class="p">,</span> <span class="s1">'conv_transpose2d'</span><span class="p">,</span> <span class="s1">'conv_transpose3d'</span><span class="p">]</span>
<span class="k">if</span> <span class="n">nonlinearity</span> <span class="ow">in</span> <span class="n">linear_fns</span> <span class="ow">or</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">'sigmoid'</span><span class="p">:</span>
<span class="k">return</span> <span class="mi">1</span>
<span class="k">elif</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">'tanh'</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">5.0</span> <span class="o">/</span> <span class="mi">3</span>
<span class="k">elif</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">'relu'</span><span class="p">:</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">'leaky_relu'</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">negative_slope</span> <span class="o">=</span> <span class="mf">0.01</span>
<span class="k">elif</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="nb">bool</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
<span class="c1"># True/False are instances of int, hence check above</span>
<span class="n">negative_slope</span> <span class="o">=</span> <span class="n">param</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"negative_slope </span><span class="si">{}</span><span class="s2"> not a valid number"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">param</span><span class="p">))</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">negative_slope</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
<span class="k">elif</span> <span class="n">nonlinearity</span> <span class="o">==</span> <span class="s1">'selu'</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">3.0</span> <span class="o">/</span> <span class="mi">4</span> <span class="c1"># Value found empirically (https://github.com/pytorch/pytorch/pull/50664)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Unsupported nonlinearity </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nonlinearity</span><span class="p">))</span></div>
<div class="viewcode-block" id="uniform_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.uniform_">[docs]</a><span class="k">def</span> <span class="nf">uniform_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with values drawn from the uniform</span>
<span class="sd"> distribution :math:`\mathcal{U}(a, b)`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> a: the lower bound of the uniform distribution</span>
<span class="sd"> b: the upper bound of the uniform distribution</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.uniform_(w)</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">has_torch_function_variadic</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">handle_torch_function</span><span class="p">(</span><span class="n">uniform_</span><span class="p">,</span> <span class="p">(</span><span class="n">tensor</span><span class="p">,),</span> <span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="n">b</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_no_grad_uniform_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span></div>
<div class="viewcode-block" id="normal_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.normal_">[docs]</a><span class="k">def</span> <span class="nf">normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">std</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with values drawn from the normal</span>
<span class="sd"> distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> mean: the mean of the normal distribution</span>
<span class="sd"> std: the standard deviation of the normal distribution</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.normal_(w)</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">has_torch_function_variadic</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">handle_torch_function</span><span class="p">(</span><span class="n">normal_</span><span class="p">,</span> <span class="p">(</span><span class="n">tensor</span><span class="p">,),</span> <span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="n">std</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_no_grad_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span></div>
<div class="viewcode-block" id="trunc_normal_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.trunc_normal_">[docs]</a><span class="k">def</span> <span class="nf">trunc_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">std</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">2.</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with values drawn from a truncated</span>
<span class="sd"> normal distribution. The values are effectively drawn from the</span>
<span class="sd"> normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`</span>
<span class="sd"> with values outside :math:`[a, b]` redrawn until they are within</span>
<span class="sd"> the bounds. The method used for generating the random values works</span>
<span class="sd"> best when :math:`a \leq \text{mean} \leq b`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> mean: the mean of the normal distribution</span>
<span class="sd"> std: the standard deviation of the normal distribution</span>
<span class="sd"> a: the minimum cutoff value</span>
<span class="sd"> b: the maximum cutoff value</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.trunc_normal_(w)</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">_no_grad_trunc_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span></div>
<div class="viewcode-block" id="constant_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.constant_">[docs]</a><span class="k">def</span> <span class="nf">constant_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">val</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with the value :math:`\text{val}`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> val: the value to fill the tensor with</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.constant_(w, 0.3)</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">has_torch_function_variadic</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">handle_torch_function</span><span class="p">(</span><span class="n">constant_</span><span class="p">,</span> <span class="p">(</span><span class="n">tensor</span><span class="p">,),</span> <span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">val</span><span class="o">=</span><span class="n">val</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_no_grad_fill_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span></div>
<div class="viewcode-block" id="ones_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.ones_">[docs]</a><span class="k">def</span> <span class="nf">ones_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with the scalar value `1`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.ones_(w)</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">_no_grad_fill_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mf">1.</span><span class="p">)</span></div>
<div class="viewcode-block" id="zeros_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.zeros_">[docs]</a><span class="k">def</span> <span class="nf">zeros_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""Fills the input Tensor with the scalar value `0`.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: an n-dimensional `torch.Tensor`</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.zeros_(w)</span>
<span class="sd"> """</span>
<span class="k">return</span> <span class="n">_no_grad_zero_</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span></div>
<div class="viewcode-block" id="eye_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.eye_">[docs]</a><span class="k">def</span> <span class="nf">eye_</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""Fills the 2-dimensional input `Tensor` with the identity</span>
<span class="sd"> matrix. Preserves the identity of the inputs in `Linear` layers, where as</span>
<span class="sd"> many inputs are preserved as possible.</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: a 2-dimensional `torch.Tensor`</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 5)</span>
<span class="sd"> >>> nn.init.eye_(w)</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndimension</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Only tensors with 2 dimensions are supported"</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="o">*</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">requires_grad</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<div class="viewcode-block" id="dirac_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.dirac_">[docs]</a><span class="k">def</span> <span class="nf">dirac_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac</span>
<span class="sd"> delta function. Preserves the identity of the inputs in `Convolutional`</span>
<span class="sd"> layers, where as many input channels are preserved as possible. In case</span>
<span class="sd"> of groups>1, each group of channels preserves identity</span>
<span class="sd"> Args:</span>
<span class="sd"> tensor: a {3, 4, 5}-dimensional `torch.Tensor`</span>
<span class="sd"> groups (int, optional): number of groups in the conv layer (default: 1)</span>
<span class="sd"> Examples:</span>
<span class="sd"> >>> w = torch.empty(3, 16, 5, 5)</span>
<span class="sd"> >>> nn.init.dirac_(w)</span>
<span class="sd"> >>> w = torch.empty(3, 24, 5, 5)</span>
<span class="sd"> >>> nn.init.dirac_(w, 3)</span>
<span class="sd"> """</span>
<span class="n">dimensions</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndimension</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dimensions</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Only tensors with 3, 4, or 5 dimensions are supported"</span><span class="p">)</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="k">if</span> <span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">%</span> <span class="n">groups</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'dim 0 must be divisible by groups'</span><span class="p">)</span>
<span class="n">out_chans_per_grp</span> <span class="o">=</span> <span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">groups</span>
<span class="n">min_dim</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">out_chans_per_grp</span><span class="p">,</span> <span class="n">sizes</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">zero_</span><span class="p">()</span>
<span class="k">for</span> <span class="n">g</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">groups</span><span class="p">):</span>
<span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">min_dim</span><span class="p">):</span>
<span class="k">if</span> <span class="n">dimensions</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span> <span class="c1"># Temporal convolution</span>
<span class="n">tensor</span><span class="p">[</span><span class="n">g</span> <span class="o">*</span> <span class="n">out_chans_per_grp</span> <span class="o">+</span> <span class="n">d</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">elif</span> <span class="n">dimensions</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span> <span class="c1"># Spatial convolution</span>
<span class="n">tensor</span><span class="p">[</span><span class="n">g</span> <span class="o">*</span> <span class="n">out_chans_per_grp</span> <span class="o">+</span> <span class="n">d</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># Volumetric convolution</span>
<span class="n">tensor</span><span class="p">[</span><span class="n">g</span> <span class="o">*</span> <span class="n">out_chans_per_grp</span> <span class="o">+</span> <span class="n">d</span><span class="p">,</span> <span class="n">d</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<span class="k">def</span> <span class="nf">_calculate_fan_in_and_fan_out</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
<span class="n">dimensions</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dimensions</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"</span><span class="p">)</span>
<span class="n">num_input_fmaps</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">num_output_fmaps</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">receptive_field_size</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">></span> <span class="mi">2</span><span class="p">:</span>
<span class="c1"># math.prod is not always available, accumulate the product manually</span>
<span class="c1"># we could use functools.reduce but that is not supported by TorchScript</span>
<span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]:</span>
<span class="n">receptive_field_size</span> <span class="o">*=</span> <span class="n">s</span>
<span class="n">fan_in</span> <span class="o">=</span> <span class="n">num_input_fmaps</span> <span class="o">*</span> <span class="n">receptive_field_size</span>
<span class="n">fan_out</span> <span class="o">=</span> <span class="n">num_output_fmaps</span> <span class="o">*</span> <span class="n">receptive_field_size</span>
<span class="k">return</span> <span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span>
<div class="viewcode-block" id="xavier_uniform_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.xavier_uniform_">[docs]</a><span class="k">def</span> <span class="nf">xavier_uniform_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="sa">r</span><span class="sd">"""Fills the input `Tensor` with values according to the method</span>
<span class="sd">    described in `Understanding the difficulty of training deep feedforward</span>
<span class="sd">    neural networks` - Glorot, X. &amp; Bengio, Y. (2010), using a uniform</span>
<span class="sd">    distribution. The resulting tensor will have values sampled from</span>
<span class="sd">    :math:`\mathcal{U}(-a, a)` where</span>

<span class="sd">    .. math::</span>
<span class="sd">        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}</span>

<span class="sd">    Also known as Glorot initialization.</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`</span>
<span class="sd">        gain: an optional scaling factor</span>

<span class="sd">    Raises:</span>
<span class="sd">        ValueError: if ``tensor`` has fewer than 2 dimensions, because fan-in</span>
<span class="sd">            and fan-out cannot be computed for it.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))</span>
<span class="sd">    """</span>
    <span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span> <span class="o">=</span> <span class="n">_calculate_fan_in_and_fan_out</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
    <span class="n">std</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">fan_in</span> <span class="o">+</span> <span class="n">fan_out</span><span class="p">))</span>
    <span class="c1"># Uniform bound with that std: sqrt(3) * std == gain * sqrt(6 / (fan_in + fan_out))</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">3.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">std</span>
    <span class="k">return</span> <span class="n">_no_grad_uniform_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="o">-</span><span class="n">a</span><span class="p">,</span> <span class="n">a</span><span class="p">)</span></div>
<div class="viewcode-block" id="xavier_normal_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.xavier_normal_">[docs]</a><span class="k">def</span> <span class="nf">xavier_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="sa">r</span><span class="sd">"""Fills the input `Tensor` with values according to the method</span>
<span class="sd">    described in `Understanding the difficulty of training deep feedforward</span>
<span class="sd">    neural networks` - Glorot, X. &amp; Bengio, Y. (2010), using a normal</span>
<span class="sd">    distribution. The resulting tensor will have values sampled from</span>
<span class="sd">    :math:`\mathcal{N}(0, \text{std}^2)` where</span>

<span class="sd">    .. math::</span>
<span class="sd">        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}</span>

<span class="sd">    Also known as Glorot initialization.</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`</span>
<span class="sd">        gain: an optional scaling factor</span>

<span class="sd">    Raises:</span>
<span class="sd">        ValueError: if ``tensor`` has fewer than 2 dimensions, because fan-in</span>
<span class="sd">            and fan-out cannot be computed for it.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.xavier_normal_(w)</span>
<span class="sd">    """</span>
    <span class="n">fan_in</span><span class="p">,</span> <span class="n">fan_out</span> <span class="o">=</span> <span class="n">_calculate_fan_in_and_fan_out</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
    <span class="n">std</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">fan_in</span> <span class="o">+</span> <span class="n">fan_out</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">_no_grad_normal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_calculate_correct_fan</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mode</span><span class="p">):</span>
    <span class="sd">"""Return the fan of ``tensor`` selected by ``mode``.</span>

<span class="sd">    ``mode`` is lower-cased and must then be ``'fan_in'`` or ``'fan_out'``;</span>
<span class="sd">    any other value raises ``ValueError`` before the fans of ``tensor`` are</span>
<span class="sd">    computed.</span>
<span class="sd">    """</span>
    <span class="n">requested</span> <span class="o">=</span> <span class="n">mode</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
    <span class="n">valid_modes</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'fan_in'</span><span class="p">,</span> <span class="s1">'fan_out'</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">requested</span> <span class="ow">in</span> <span class="n">valid_modes</span><span class="p">:</span>
        <span class="c1"># Pair each mode name with its entry of the (fan_in, fan_out) tuple</span>
        <span class="n">fans_by_mode</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">valid_modes</span><span class="p">,</span> <span class="n">_calculate_fan_in_and_fan_out</span><span class="p">(</span><span class="n">tensor</span><span class="p">)))</span>
        <span class="k">return</span> <span class="n">fans_by_mode</span><span class="p">[</span><span class="n">requested</span><span class="p">]</span>
    <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Mode </span><span class="si">{}</span><span class="s2"> not supported, please use one of </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">requested</span><span class="p">,</span> <span class="n">valid_modes</span><span class="p">))</span>
<div class="viewcode-block" id="kaiming_uniform_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.kaiming_uniform_">[docs]</a><span class="k">def</span> <span class="nf">kaiming_uniform_</span><span class="p">(</span>
    <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'fan_in'</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'leaky_relu'</span>
<span class="p">):</span>
    <span class="sa">r</span><span class="sd">"""He (Kaiming) initialization with a uniform distribution.</span>

<span class="sd">    Follows `Delving deep into rectifiers: Surpassing human-level</span>
<span class="sd">    performance on ImageNet classification` - He, K. et al. (2015):</span>
<span class="sd">    ``tensor`` is filled in place with samples from</span>
<span class="sd">    :math:`\mathcal{U}(-\text{bound}, \text{bound})`, where</span>

<span class="sd">    .. math::</span>
<span class="sd">        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: an n-dimensional `torch.Tensor`; unless it has zero elements,</span>
<span class="sd">            :math:`n \geq 2` is required so that its fans can be computed</span>
<span class="sd">        a: negative slope of the rectifier that follows this layer (only</span>
<span class="sd">            used with ``'leaky_relu'``)</span>
<span class="sd">        mode: ``'fan_in'`` (default) preserves the magnitude of the variance of</span>
<span class="sd">            the weights in the forward pass; ``'fan_out'`` preserves the</span>
<span class="sd">            magnitudes in the backwards pass. Matched case-insensitively.</span>
<span class="sd">        nonlinearity: the non-linear function (`nn.functional` name); best</span>
<span class="sd">            used with ``'relu'`` or ``'leaky_relu'`` (default).</span>

<span class="sd">    Returns:</span>
<span class="sd">        ``tensor`` itself. A zero-element ``tensor`` is returned untouched</span>
<span class="sd">        after a warning; Tensor-like inputs implementing ``__torch_function__``</span>
<span class="sd">        get whatever their override returns.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')</span>
<span class="sd">    """</span>
    <span class="n">overrides</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span>
    <span class="k">if</span> <span class="n">overrides</span><span class="o">.</span><span class="n">has_torch_function_variadic</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
        <span class="c1"># Hand the call to a Tensor-like argument's __torch_function__ override</span>
        <span class="k">return</span> <span class="n">overrides</span><span class="o">.</span><span class="n">handle_torch_function</span><span class="p">(</span>
            <span class="n">kaiming_uniform_</span><span class="p">,</span> <span class="p">(</span><span class="n">tensor</span><span class="p">,),</span>
            <span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">mode</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="n">nonlinearity</span><span class="p">,</span>
        <span class="p">)</span>
    <span class="k">if</span> <span class="mi">0</span> <span class="ow">in</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
        <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"Initializing zero-element tensors is a no-op"</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">tensor</span>
    <span class="c1"># Resolve the fan before the gain so bad mode/shape errors surface first</span>
    <span class="n">selected_fan</span> <span class="o">=</span> <span class="n">_calculate_correct_fan</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mode</span><span class="p">)</span>
    <span class="n">gain</span> <span class="o">=</span> <span class="n">calculate_gain</span><span class="p">(</span><span class="n">nonlinearity</span><span class="p">,</span> <span class="n">a</span><span class="p">)</span>
    <span class="c1"># U(-b, b) has std b / sqrt(3); pick b so the std equals gain / sqrt(fan)</span>
    <span class="n">bound</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">3.0</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">gain</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">selected_fan</span><span class="p">))</span>
    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="n">bound</span><span class="p">,</span> <span class="n">bound</span><span class="p">)</span></div>
<div class="viewcode-block" id="kaiming_normal_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.kaiming_normal_">[docs]</a><span class="k">def</span> <span class="nf">kaiming_normal_</span><span class="p">(</span>
    <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'fan_in'</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'leaky_relu'</span>
<span class="p">):</span>
    <span class="sa">r</span><span class="sd">"""Fills the input `Tensor` with values according to the method</span>
<span class="sd">    described in `Delving deep into rectifiers: Surpassing human-level</span>
<span class="sd">    performance on ImageNet classification` - He, K. et al. (2015), using a</span>
<span class="sd">    normal distribution. The resulting tensor will have values sampled from</span>
<span class="sd">    :math:`\mathcal{N}(0, \text{std}^2)` where</span>

<span class="sd">    .. math::</span>
<span class="sd">        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}</span>

<span class="sd">    Also known as He initialization.</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`</span>
<span class="sd">            (zero-element tensors are returned unchanged with a warning)</span>
<span class="sd">        a: the negative slope of the rectifier used after this layer (only</span>
<span class="sd">            used with ``'leaky_relu'``)</span>
<span class="sd">        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``</span>
<span class="sd">            preserves the magnitude of the variance of the weights in the</span>
<span class="sd">            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the</span>
<span class="sd">            backwards pass.</span>
<span class="sd">        nonlinearity: the non-linear function (`nn.functional` name),</span>
<span class="sd">            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).</span>

<span class="sd">    Raises:</span>
<span class="sd">        ValueError: for a non-empty ``tensor``, if ``mode`` is not ``'fan_in'``</span>
<span class="sd">            or ``'fan_out'`` (case-insensitive), or if ``tensor`` has fewer</span>
<span class="sd">            than 2 dimensions.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')</span>
<span class="sd">    """</span>
    <span class="c1"># Same __torch_function__ dispatch as kaiming_uniform_, so Tensor-like</span>
    <span class="c1"># arguments can override this initializer too.</span>
    <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">has_torch_function_variadic</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">overrides</span><span class="o">.</span><span class="n">handle_torch_function</span><span class="p">(</span>
            <span class="n">kaiming_normal_</span><span class="p">,</span>
            <span class="p">(</span><span class="n">tensor</span><span class="p">,),</span>
            <span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="p">,</span>
            <span class="n">a</span><span class="o">=</span><span class="n">a</span><span class="p">,</span>
            <span class="n">mode</span><span class="o">=</span><span class="n">mode</span><span class="p">,</span>
            <span class="n">nonlinearity</span><span class="o">=</span><span class="n">nonlinearity</span><span class="p">)</span>
    <span class="k">if</span> <span class="mi">0</span> <span class="ow">in</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
        <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"Initializing zero-element tensors is a no-op"</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">tensor</span>
    <span class="n">fan</span> <span class="o">=</span> <span class="n">_calculate_correct_fan</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">mode</span><span class="p">)</span>
    <span class="n">gain</span> <span class="o">=</span> <span class="n">calculate_gain</span><span class="p">(</span><span class="n">nonlinearity</span><span class="p">,</span> <span class="n">a</span><span class="p">)</span>
    <span class="n">std</span> <span class="o">=</span> <span class="n">gain</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">fan</span><span class="p">)</span>
    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="k">return</span> <span class="n">tensor</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span></div>
<div class="viewcode-block" id="orthogonal_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.orthogonal_">[docs]</a><span class="k">def</span> <span class="nf">orthogonal_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">gain</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
    <span class="sa">r</span><span class="sd">"""Fills the input `Tensor` with a (semi) orthogonal matrix, as</span>
<span class="sd">    described in `Exact solutions to the nonlinear dynamics of learning in deep</span>
<span class="sd">    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have</span>
<span class="sd">    at least 2 dimensions, and for tensors with more than 2 dimensions the</span>
<span class="sd">    trailing dimensions are flattened. Zero-element tensors are returned</span>
<span class="sd">    unchanged.</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`</span>
<span class="sd">        gain: optional scaling factor</span>

<span class="sd">    Raises:</span>
<span class="sd">        ValueError: if ``tensor`` has fewer than 2 dimensions.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.orthogonal_(w)</span>
<span class="sd">    """</span>
    <span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndimension</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Only tensors with 2 or more dimensions are supported"</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="c1"># no-op</span>
        <span class="k">return</span> <span class="n">tensor</span>
    <span class="n">rows</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">cols</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">//</span> <span class="n">rows</span>
    <span class="c1"># Gaussian (rows, cols) matrix with the same dtype/device as tensor;</span>
    <span class="c1"># new_empty replaces the legacy Tensor.new constructor.</span>
    <span class="n">flattened</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">new_empty</span><span class="p">((</span><span class="n">rows</span><span class="p">,</span> <span class="n">cols</span><span class="p">))</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">rows</span> <span class="o">&lt;</span> <span class="n">cols</span><span class="p">:</span>
        <span class="c1"># Factor the tall orientation; transposed back after the QR step</span>
        <span class="n">flattened</span><span class="o">.</span><span class="n">t_</span><span class="p">()</span>

    <span class="c1"># Compute the qr factorization</span>
    <span class="n">q</span><span class="p">,</span> <span class="n">r</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">qr</span><span class="p">(</span><span class="n">flattened</span><span class="p">)</span>
    <span class="c1"># Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf:</span>
    <span class="c1"># scale each column of Q by the sign of the matching diagonal entry of R</span>
    <span class="n">d</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">diag</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">ph</span> <span class="o">=</span> <span class="n">d</span><span class="o">.</span><span class="n">sign</span><span class="p">()</span>
    <span class="n">q</span> <span class="o">*=</span> <span class="n">ph</span>

    <span class="k">if</span> <span class="n">rows</span> <span class="o">&lt;</span> <span class="n">cols</span><span class="p">:</span>
        <span class="n">q</span><span class="o">.</span><span class="n">t_</span><span class="p">()</span>

    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="c1"># Write q into tensor through a view (needs a view-compatible layout)</span>
        <span class="n">tensor</span><span class="o">.</span><span class="n">view_as</span><span class="p">(</span><span class="n">q</span><span class="p">)</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
        <span class="n">tensor</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">gain</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">tensor</span></div>
<div class="viewcode-block" id="sparse_"><a class="viewcode-back" href="../../../nn.init.html#torch.nn.init.sparse_">[docs]</a><span class="k">def</span> <span class="nf">sparse_</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">sparsity</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.01</span><span class="p">):</span>
    <span class="sa">r</span><span class="sd">"""Fills the 2D input `Tensor` as a sparse matrix, where the</span>
<span class="sd">    non-zero elements will be drawn from the normal distribution</span>
<span class="sd">    :math:`\mathcal{N}(0, \text{std}^2)`, as described in `Deep learning via</span>
<span class="sd">    Hessian-free optimization` - Martens, J. (2010).</span>

<span class="sd">    Args:</span>
<span class="sd">        tensor: a 2-dimensional `torch.Tensor`</span>
<span class="sd">        sparsity: The fraction of elements in each column to be set to zero;</span>
<span class="sd">            each column gets ``ceil(sparsity * rows)`` zeros at random rows</span>
<span class="sd">        std: the standard deviation of the normal distribution used to generate</span>
<span class="sd">            the non-zero values (default: 0.01)</span>

<span class="sd">    Raises:</span>
<span class="sd">        ValueError: if ``tensor`` is not 2-dimensional.</span>

<span class="sd">    Examples:</span>
<span class="sd">        &gt;&gt;&gt; w = torch.empty(3, 5)</span>
<span class="sd">        &gt;&gt;&gt; nn.init.sparse_(w, sparsity=0.1)</span>
<span class="sd">    """</span>
    <span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndimension</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Only tensors with 2 dimensions are supported"</span><span class="p">)</span>

    <span class="n">rows</span><span class="p">,</span> <span class="n">cols</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span>
    <span class="c1"># Number of entries zeroed in every column, rounded up</span>
    <span class="n">num_zeros</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">sparsity</span> <span class="o">*</span> <span class="n">rows</span><span class="p">))</span>

    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="n">tensor</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">col_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">cols</span><span class="p">):</span>
            <span class="c1"># A fresh random permutation per column picks which rows to zero</span>
            <span class="n">row_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randperm</span><span class="p">(</span><span class="n">rows</span><span class="p">)</span>
            <span class="n">zero_indices</span> <span class="o">=</span> <span class="n">row_indices</span><span class="p">[:</span><span class="n">num_zeros</span><span class="p">]</span>
            <span class="n">tensor</span><span class="p">[</span><span class="n">zero_indices</span><span class="p">,</span> <span class="n">col_idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">return</span> <span class="n">tensor</span></div>
<span class="c1"># for backward compatibility</span>
<span class="k">def</span> <span class="nf">_make_deprecate</span><span class="p">(</span><span class="n">meth</span><span class="p">):</span>
<span class="n">new_name</span> <span class="o">=</span> <span class="n">meth</span><span class="o">.</span><span class="vm">__name__</span>
<span class="n">old_name</span> <span class="o">=</span> <span class="n">new_name</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">def</span> <span class="nf">deprecated_init</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">"nn.init.</span><span class="si">{}</span><span class="s2"> is now deprecated in favor of nn.init.</span><span class="si">{}</span><span class="s2">."</span>
<span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">old_name</span><span class="p">,</span> <span class="n">new_name</span><span class="p">),</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">return</span> <span class="n">meth</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">deprecated_init</span><span class="o">.</span><span class="vm">__doc__</span> <span class="o">=</span> <span class="sa">r</span><span class="s2">"""</span>
<span class="s2"> </span><span class="si">{old_name}</span><span class="s2">(...)</span>
<span class="s2"> .. warning::</span>
<span class="s2"> This method is now deprecated in favor of :func:`torch.nn.init.</span><span class="si">{new_name}</span><span class="s2">`.</span>