# srntt-pytorch-master/models/__init__.py
import torch.nn.init as init
from .discriminator import Discriminator, ImageDiscriminator
from .srntt import SRNTT, ContentExtractor
from .swapper import Swapper
from .vgg import VGG
# Names re-exported by ``from models import *``.
# NOTE(review): init_weights, defined in this module, is not listed here, so a
# star-import does not export it -- confirm that is intended.
__all__ = [
    # discriminators
    'Discriminator',
    'ImageDiscriminator',
    # SRNTT network and its feature extractor
    'SRNTT',
    'ContentExtractor',
    # texture swapping and VGG backbone
    'Swapper',
    'VGG',
]
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize the parameters of every sub-module of ``net`` in place.

    Modules are visited with ``net.apply``. For each module:

    * If its class name contains ``'Conv'`` or ``'Linear'`` and it has a
      ``weight`` attribute, the weight is filled according to ``init_type``.
      Its bias, if present and not ``None``, is set to 0.
    * Otherwise, if its class name contains ``'BatchNorm2d'``, the weight is
      drawn from N(1.0, init_gain) and the bias is set to 0. Non-affine
      batch-norm layers (``weight``/``bias`` registered as ``None``) are
      skipped instead of raising.

    Args:
        net: module whose sub-modules (and itself) are initialized.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
            ``'orthogonal'``.
        init_gain: std for ``'normal'``; gain for ``'xavier'`` and
            ``'orthogonal'``; not used by ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is not a supported method.
            This is checked before any module is modified, so a typo is
            reported even when ``net`` contains no conv/linear layers.
    """
    # Each entry fills a weight tensor in place using the chosen scheme.
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }
    if init_type not in weight_initializers:
        raise NotImplementedError(
            f'initialization method [{init_type}] is not implemented')
    init_weight = weight_initializers[init_type]

    def init_func(m):
        # Dispatch on the module's class name (e.g. 'Conv2d',
        # 'ConvTranspose2d', 'Linear', 'BatchNorm2d'), so any class whose
        # name contains these substrings is treated the same way.
        name = m.__class__.__name__
        # hasattr() skips modules whose name matches but that carry no
        # 'weight' attribute at all.
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if m.weight is not None:
                init_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm2d(affine=False) registers weight/bias as None.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    # Apply init_func to every module in net (recursively, including net).
    net.apply(init_func)