import torch.nn.init as init

from .discriminator import Discriminator, ImageDiscriminator
from .srntt import SRNTT, ContentExtractor
from .swapper import Swapper
from .vgg import VGG

# Public API of this package: the model classes re-exported above plus the
# weight-initialization helper defined below.
__all__ = [
    'Discriminator', 'ImageDiscriminator', 'SRNTT', 'ContentExtractor',
    'Swapper', 'VGG', 'init_weights',
]

def init_weights(net, init_type='normal', init_gain=0.02):
    def init_func(m):
        """Initialize the parameters of a single module ``m`` in place.

        Modules whose class name contains ``'Conv'`` or ``'Linear'`` and that
        have a ``weight`` attribute get their weight initialized according to
        ``init_type`` (from the enclosing ``init_weights`` call, scaled by
        ``init_gain`` where applicable), and their bias, if present, set to 0.
        Modules whose class name contains ``'BatchNorm2d'`` get weight drawn
        from N(1.0, init_gain) and bias set to 0; a BatchNorm2d without affine
        parameters (weight/bias are ``None``) is skipped. Any other module is
        left untouched.

        Args:
            m: the module to initialize.

        Raises:
            NotImplementedError: if a Conv/Linear module is visited and
                ``init_type`` is not one of 'normal', 'xavier', 'kaiming',
                'orthogonal'.
        """
        # Dispatch on the class name, e.g. 'Conv2d', 'ConvTranspose2d',
        # 'Linear', 'BatchNorm2d'.
        name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type == 'normal':
                init.normal_(m.weight, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
            else:
                # Unknown strategy: fail loudly rather than silently keep
                # the module's default initialization.
                raise NotImplementedError(
                    f'initialization method [{init_type}] is not implemented')
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm2d(affine=False) has weight/bias set to None.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0.0)