# srntt-pytorch-master/models/__init__.py

import torch.nn.init as init

from .discriminator import Discriminator, ImageDiscriminator
from .srntt import SRNTT, ContentExtractor
from .swapper import Swapper
from .vgg import VGG

__all__ = [
    'Discriminator',
    'ImageDiscriminator',
    'SRNTT',
    'ContentExtractor',
    'Swapper',
    'VGG'
]


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize the weights of `net` with the chosen scheme."""
    def init_func(m):
        # m.__class__.__name__ returns the name of the module's class, e.g.:
        #   >>> class Test:
        #   ...     pass
        #   >>> test = Test()
        #   >>> test.__class__.__name__
        #   'Test'
        name = m.__class__.__name__
        # Apply the chosen scheme only to Conv*/Linear layers that actually
        # carry a weight tensor. hasattr checks whether an object has the
        # given attribute, e.g.:
        #   class Coordinate:
        #       x = 10
        #       y = -5
        #       z = 0
        #   point1 = Coordinate()
        #   print(hasattr(point1, 'x'))   # True
        #   print(hasattr(point1, 'y'))   # True
        #   print(hasattr(point1, 'z'))   # True
        #   print(hasattr(point1, 'no'))  # False
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                # Unknown init_type: raise an error instead of silently skipping
                raise NotImplementedError(
                    f'initialization method [{init_type}] is not implemented')
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
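            # BatchNorm affine scale is initialized around 1.0 (identity
            # scaling) with a small std, and the shift (bias) at 0.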
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply init_func to every submodule of net
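

# Usage sketch (illustrative, not part of the original module): initialize a
# freshly constructed network before training. The no-argument ContentExtractor
# call assumes its defaults; check models/srntt.py for the real signature.
#
#   from models import ContentExtractor, init_weights
#   net = ContentExtractor()
#   init_weights(net, init_type='kaiming')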