! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastaiGAN
from __future__ import annotations
from fastai.basics import *
from fastai.vision.all import *::: {#cell-3 .cell 0=‘d’ 1=‘e’ 2=‘f’ 3=‘a’ 4=‘u’ 5=‘l’ 6=‘t’ 7=’_’ 8=‘e’ 9=‘x’ 10=‘p’ 11=’ ’ 12=‘视’ 13=‘觉’ 14=‘.’ 15=‘生’ 16=‘成’ 17=‘对’ 18=‘抗’ 19=‘网’ 20=‘络’}
### 默认类级别 3:::
from nbdev.showdoc import *对生成对抗网络的基本支持
GAN 代表 生成对抗网络,由 Ian Goodfellow 发明。其概念是我们同时训练两个模型:一个生成器和一个鉴别器。生成器会尝试生成与数据集中相似的新图像,而鉴别器则会尝试区分真实图像和生成器生成的图像。生成器输出图像,鉴别器输出一个数字(通常是一个概率,真实图像为 1,假图像为 0)。
我们以相互对抗的方式训练它们,具体步骤如下(或多或少):
- 冻结生成器,训练鉴别器一步:
  - 获取一批真实图像(我们称之为 `real`)
  - 生成一批假图像(我们称之为 `fake`)
  - 让鉴别器评估每一批图像并计算损失;关键在于,该损失会奖励鉴别器识别出真实图像,并惩罚假图像
  - 用这个损失的梯度更新鉴别器的权重
- 冻结鉴别器,训练生成器一步:
  - 生成一批假图像
  - 用鉴别器评估这批图像
  - 返回一个损失:鉴别器越认为这些是真实图像,奖励就越大
  - 用这个损失的梯度更新生成器的权重
fastai 库通过 `GANTrainer` 为训练 GAN 提供支持,但除了基本模型之外不包含其他模型。
封装模块
class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    # This is only a shell holding the two models: when called, it forwards its
    # inputs to `generator` or to `critic` depending on the value of `gen_mode`.
    def __init__(self,
        generator:nn.Module=None, # The generator PyTorch module
        critic:nn.Module=None, # The discriminator PyTorch module
        gen_mode:None|bool=False # Whether the GAN should be set to generator mode
    ):
        if generator is not None: self.generator = generator
        if critic is not None: self.critic = critic
        store_attr('gen_mode')

    def forward(self, *args):
        "Run `args` through the generator in generator mode, through the critic otherwise."
        active = self.generator if self.gen_mode else self.critic
        return active(*args)

    def switch(self,
        gen_mode:None|bool=None # Whether the GAN should be set to generator mode
    ):
        "Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise."
        if gen_mode is None:
            # No explicit mode requested: flip to the other one.
            gen_mode = not self.gen_mode
        self.gen_mode = gen_mode

show_doc(GANModule.switch)
GANModule.switch[source]
GANModule.switch(gen_mode:(None, <class 'bool'>)=None)
Put the module in generator mode if gen_mode is True, in critic mode otherwise.
| Type | Default | Details | |
|---|---|---|---|
gen_mode |
(None, bool) |
None |
Whether the GAN should be set to generator mode |
默认情况下(gen_mode 保持为 None 时),这会将模块切换到另一种模式:如果当前处于生成器模式,则切换到判别器(critic)模式,反之亦然。
@delegates(ConvLayer.__init__)
def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the input for the critic
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    **kwargs
) -> nn.Sequential:
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    # The first stride-2 conv halves the spatial size; the input layer gets no normalization.
    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]
    # Halve the spatial size (doubling the features each time) until it is at most 4.
    while cur_size > 4:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))
        cur_ftrs *= 2 ; cur_size //= 2
    init = kwargs.get('init', nn.init.kaiming_normal_)
    # Final unpadded 4x4 conv down to a single channel, flattened to one score per image
    # (assumes the feature map is 4x4 here, i.e. `in_size` is a power of 2 -- TODO confirm).
    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]
    return nn.Sequential(*layers)


class AddChannels(Module):
    "Add `n_dim` channels at the end of the input."
    def __init__(self, n_dim): self.n_dim=n_dim
    # Append `n_dim` trailing singleton dims, e.g. (bs, n) -> (bs, n, 1, 1) for `n_dim=2`.
    def forward(self, x): return x.view(*(list(x.shape)+[1]*self.n_dim))


@delegates(ConvLayer.__init__)
def basic_generator(
    out_size:int, # Output size for the generator (same as the input size for the critic)
    n_channels:int, # Number of channels of the output of the generator
    in_sz:int=100, # Size of the input noise vector for the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    **kwargs
) -> nn.Sequential:
    "A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`."
    # Compute the feature count of the first (4x4) layer: it doubles for every upsampling step.
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < out_size: cur_size *= 2; cur_ftrs *= 2
    # Noise vector (bs, in_sz) -> (bs, in_sz, 1, 1) -> 4x4 feature map.
    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]
    cur_size = 4
    while cur_size < out_size // 2:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))
        cur_ftrs //= 2; cur_size *= 2
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]
    # The last upsampling step produces the image; tanh maps pixel values to [-1, 1].
    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)


critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])


# Default conv settings for `gan_critic`: LeakyReLU(0.2) activation and spectral normalization.
_conv_args = dict(act_cls = partial(nn.LeakyReLU, negative_slope=0.2), norm_type=NormType.Spectral)

def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):
    "`ConvLayer` with `_conv_args`, optionally followed by a `SelfAttention` layer."
    if self_attention: kwargs['xtra'] = SelfAttention(nf)
    return ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)


@delegates(ConvLayer)
def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    **kwargs
) -> SequentialEx:
    "Resnet block of `nf` features. `kwargs` are passed to `ConvLayer`."
    # `MergeLayer(dense=True)` concatenates the block input with its output,
    # so the block returns `2*nf` features.
    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        MergeLayer(dense=True))


def gan_critic(
    n_channels:int=3, # Number of input channels of the critic
    nf:int=128, # Number of features of the critic
    n_blocks:int=3, # Number of ResNet blocks within the critic
    p:float=0.15 # Amount of dropout in the critic
) -> nn.Sequential:
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        DenseResBlock(nf, **_conv_args)]
    nf *= 2 # the dense block concatenates its input, doubling the features
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
        Flatten()]
    return nn.Sequential(*layers)


class GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self,
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        gan_model:GANModule # The GAN model
    ):
        super().__init__()
        store_attr('gen_loss_func,crit_loss_func,gan_model')

    def generator(self,
        output, # Generator outputs
        target # Real images
    ):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`"
        fake_pred = self.gan_model.critic(output)
        self.gen_loss = self.gen_loss_func(fake_pred, output, target)
        return self.gen_loss

    def critic(self,
        real_pred, # Critic predictions for real images
        input # Input noise vector to pass into generator
    ):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        # Detach the generated images so no gradient flows back into the generator during the
        # critic step. `.requires_grad_(False)` only works on leaf tensors and raises a
        # RuntimeError whenever the generator's parameters are not frozen; `.detach()` is
        # safe in both cases.
        fake = self.gan_model.generator(input).detach()
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss

show_doc(GANLoss.generator)
GANLoss.generator[source]
GANLoss.generator(output,target)
Evaluate the output with the critic then uses self.gen_loss_func to evaluate how well the critic was fooled by output
| Type | Default | Details | |
|---|---|---|---|
output |
Generator outputs | ||
target |
Real images |
show_doc(GANLoss.critic)  # Render the API documentation for `GANLoss.critic`
GANLoss.critic[source]
GANLoss.critic(real_pred,input)
Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.
| Type | Default | Details | |
|---|---|---|---|
real_pred |
Critic predictions for real images | ||
input |
Input noise vector to pass into generator |
如果调用generator方法,则该损失函数期望接收生成器的output和一些target(一批真实图像)。它将使用gen_loss_func评估生成器是否成功欺骗了鉴别器。该损失函数具有以下签名
def gen_loss_func(fake_pred, output, target):
以便能够将鉴别器对output的输出(第一个参数fake_pred)与output和target结合起来(例如,如果你想将GAN损失与其他损失混合)。
如果调用critic方法,则该损失函数期望接收鉴别器给出的real_pred和一些input(馈送给生成器的噪声)。它将使用crit_loss_func评估鉴别器。该损失函数具有以下签名
def crit_loss_func(real_pred, fake_pred):
其中 real_pred 是鉴别器对一批真实图像的输出,而 fake_pred 是鉴别器对生成器根据噪声生成的图像的输出。
class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit:callable):
        self.crit = crit

    def forward(self, output:Tensor, target:Tensor):
        # Insert an axis after the batch dim, broadcast to `output`'s shape, cast to float.
        expanded = target[:, None].expand_as(output).float()
        return self.crit(output, expanded)


def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True):
    "Compute thresholded accuracy after expanding `y_true` to the size of `y_pred`."
    probs = y_pred.sigmoid() if sigmoid else y_pred
    hits = (probs > thresh).byte()
    targets = y_true[:, None].expand_as(probs).byte()
    return (hits == targets).float().mean()

# Callbacks for GAN training
def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool # `requires_grad` argument. `False` for frozen, `True` for unfrozen
):
    "Set `requires_grad` to `rg` on every parameter of `m` (`False` freezes it, `True` unfreezes it)."
    for p in m.parameters(): p.requires_grad_(rg)


class GANTrainer(Callback):
    "Callback to handle GAN Training."
    run_after = TrainEvalCallback
    def __init__(self,
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        clip:None|float=None, # How much to clip the weights
        beta:float=0.98, # Exponentially weighted smoothing of the losses `beta`
        gen_first:bool=False, # Whether we start with generator training
        show_img:bool=True, # Whether to show example generated images during training
    ):
        store_attr('switch_eval,clip,gen_first,show_img')
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)

    def _set_trainable(self):
        "Appropriately set the generator and critic into a trainable or loss evaluation mode based on `self.gen_mode`."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        set_freeze_model(train_model, True)
        set_freeze_model(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def before_fit(self):
        "Initialization."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
        if self.training and self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            # Critic step: the real images (`yb`) become the critic's input and the noise (`xb`)
            # is passed as the "target", i.e. the `input` argument of `GANLoss.critic`.
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)

    def after_batch(self):
        "Record `last_loss` in the proper list."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            self.last_gen = self.learn.to_detach(self.pred)
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        # Pass the resolved mode (not the raw, possibly-`None` argument): otherwise the model
        # and the loss function each toggle their own flag and can drift out of sync with
        # the trainer.
        self.model.switch(self.gen_mode)
        self.loss_func.switch(self.gen_mode)

# `GANTrainer` is useless on its own: it needs to be completed with one of the switchers below.
class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self,
        n_crit:int=1, # How many steps of critic training before switching to generator
        n_gen:int=1 # How many steps of generator training before switching to critic
    ):
        # `n_crit`/`n_gen` may also be callables (they are called in `after_batch`).
        store_attr('n_crit,n_gen')

    def before_train(self):
        "Reset the critic/generator step counters."
        self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        # NOTE(review): both counters are reset on every switch, so a callable `n_iter`
        # is always called with 0 here -- confirm this is the intended argument.
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0


class AdaptiveGANSwitcher(Callback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    run_after = GANTrainer
    def __init__(self,
        gen_thresh:None|float=None, # Loss threshold for generator
        critic_thresh:None|float=None # Loss threshold for critic
    ):
        store_attr('gen_thresh,critic_thresh')

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        trainer = self.learn.gan_trainer
        # A `None` threshold means: switch after every batch.
        thresh = self.gen_thresh if trainer.gen_mode else self.critic_thresh
        if thresh is None or self.learn.loss < thresh: trainer.switch()


class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    run_after = GANTrainer
    def __init__(self, mult_lr=5.):
        self.mult_lr = mult_lr
        # Whether the LR was multiplied in `before_batch` and must be restored in `after_batch`.
        self._lr_scaled = False

    def before_batch(self):
        "Multiply the current lr if necessary."
        self._lr_scaled = self.training and not self.learn.gan_trainer.gen_mode
        if self._lr_scaled:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        # Undo exactly what `before_batch` did. Re-checking `gen_mode` here is unreliable:
        # a switcher may have flipped the mode during this batch, and validation batches
        # were never scaled.
        if not self._lr_scaled: return
        self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)
        self._lr_scaled = False

# GAN data
class InvisibleTensor(TensorBase):
    "TensorBase but show method does nothing"
    def show(self, ctx=None, **kwargs):
        # Nothing is drawn: the context is handed back untouched.
        return ctx


def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size=100 # Size of returned noise vector
) -> InvisibleTensor:
    "Generate noise vector."
    noise = torch.randn(size)
    return cast(noise, InvisibleTensor)

# `generate_noise` produces the noise vectors that are passed to the generator to create images.
@typedispatch
def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Lay out a grid if needed, then defer to the generic `show_batch` (`x.show` draws nothing)."
    if ctxs is None:
        n_shown = min(len(samples), max_n)
        ctxs = get_grid(n_shown, nrows=nrows, ncols=ncols, figsize=figsize)
    return show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)


@typedispatch
def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show the first item of each output (at most `max_n`) on a grid."
    if ctxs is None:
        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    shown = []
    for out, ctx, _ in zip(outs.itemgot(0), ctxs, range(max_n)):
        shown.append(out.show(ctx=ctx, **kwargs))
    return shown


bs = 128
size = 64

mean, std = torch.tensor([0.5]*3), torch.tensor([0.5]*3)
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),  # no validation indices: every image is used for training
                   item_tfms = Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(mean, std))

path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
GAN 学习器
def gan_loss_from_func(
    loss_gen:callable, # A loss function for the generator. Evaluates generator output images and target real images
    loss_crit:callable, # A loss function for the critic. Evaluates predictions of real and fake images.
    weights_gen:None|MutableSequence|tuple=None # Weights of the two generator-loss terms (`loss_crit` on fakes, `loss_gen` on outputs); defaults to `(1.,1.)`
):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        # Generator loss: critic loss of the fake predictions against "real" labels (1s),
        # plus `loss_gen` comparing the generated images with the targets.
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)
    def _loss_C(real_pred, fake_pred):
        # Critic loss: average of the loss on real images (labelled 1) and fake images (labelled 0).
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
    return _loss_G, _loss_C


# WGAN losses used by `GANLearner.wgan`: the generator loss is the mean critic score on fakes,
# the critic loss is the mean score on real images minus the mean score on fakes.
def _tk_mean(fake_pred, output, target): return fake_pred.mean()
def _tk_diff(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()


@delegates()
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        gen_first:bool=False, # Whether we start with generator training
        switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
        show_img:bool=True, # Whether to show example generated images during training
        clip:None|float=None, # How much to clip the weights
        cbs:Callback|None|MutableSequence=None, # Additional callbacks
        metrics:None|MutableSequence|callable=None, # Metrics
        **kwargs
    ):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher()
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        # Add the GAN trainer and the switcher to any user-supplied callbacks.
        cbs = L(cbs) + L(trainer, switcher)
        # Also report `gen_loss` and `crit_loss` as metrics.
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls,
        gen_learn:Learner, # A `Learner` object that contains the generator
        crit_learn:Learner, # A `Learner` object that contains the critic
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        weights_gen:None|MutableSequence|tuple=None, # Weights of the two generator-loss terms, see `gan_loss_from_func`
        **kwargs
    ):
        "Create a GAN from `gen_learn` and `crit_learn`."
        # The learners' own loss functions become the two building blocks of the GAN losses.
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
        clip:None|float=0.01, # How much to clip the weights
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        **kwargs
    ):
        "Create a [WGAN](https://arxiv.org/abs/1701.07875) from `dls`, `generator` and `critic`."
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

# Expose the `__init__` keyword arguments in the signatures of the alternate constructors.
GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)
GANLearner.wgan = delegates(to=GANLearner.__init__)(GANLearner.wgan)

show_doc(GANLearner.from_learners)
GANLearner.from_learners[source]
GANLearner.from_learners(gen_learn:Learner,crit_learn:Learner,switcher:Callback'>, None)=None,weights_gen:(None, <class 'list'>, <class 'tuple'>)=None,gen_first:bool=False,switch_eval:bool=True,show_img:bool=True,clip:(None, <class 'float'>)=None,cbs:Callback'>, None, <class 'list'>)=None,metrics:(None, <class 'list'>, <built-in function callable>)=None,loss_func=None,opt_func=Adam,lr=0.001,splitter=trainable_params,path=None,model_dir='models',wd=None,wd_bn_bias=False,train_bn=True,moms=(0.95, 0.85, 0.95))
Create a GAN from learn_gen and learn_crit.
| Type | Default | Details | |
|---|---|---|---|
gen_learn |
Learner |
A Learner object that has the generator |
|
crit_learn |
Learner |
A Learner object that has the critic |
|
switcher |
(Callback, None) |
None |
Callback for switching between generator and critic training, defaults to FixedGANSwitcher |
weights_gen |
(None, list, tuple) |
None |
Weights for the generator and critic loss function |
gen_first |
bool |
False |
No Content |
switch_eval |
bool |
True |
No Content |
show_img |
bool |
True |
No Content |
clip |
(None, float) |
None |
No Content |
cbs |
(Callback, None, list) |
None |
No Content |
metrics |
(None, list, callable) |
None |
No Content |
loss_func |
NoneType |
None |
No Content |
opt_func |
function |
<function Adam> |
No Content |
lr |
float |
0.001 |
No Content |
splitter |
function |
<function trainable_params> |
No Content |
path |
NoneType |
None |
No Content |
model_dir |
str |
models |
No Content |
wd |
NoneType |
None |
No Content |
wd_bn_bias |
bool |
False |
No Content |
train_bn |
bool |
True |
No Content |
moms |
tuple |
(0.95, 0.85, 0.95) |
No Content |
show_doc(GANLearner.wgan)  # Render the API documentation for `GANLearner.wgan`
GANLearner.wgan[source]
GANLearner.wgan(dls:DataLoaders,generator:Module,critic:Module,switcher:Callback'>, None)=None,clip:(None, <class 'float'>)=0.01,switch_eval:bool=False,gen_first:bool=False,show_img:bool=True,cbs:Callback'>, None, <class 'list'>)=None,metrics:(None, <class 'list'>, <built-in function callable>)=None,loss_func=None,opt_func=Adam,lr=0.001,splitter=trainable_params,path=None,model_dir='models',wd=None,wd_bn_bias=False,train_bn=True,moms=(0.95, 0.85, 0.95))
Create a WGAN from dls, generator and critic.
| Type | Default | Details | |
|---|---|---|---|
dls |
DataLoaders |
DataLoaders object for GAN data | |
generator |
Module |
Generator model | |
critic |
Module |
Critic model | |
switcher |
(Callback, None) |
None |
Callback for switching between generator and critic training, defaults to FixedGANSwitcher(n_crit=5, n_gen=1) |
clip |
(None, float) |
0.01 |
How much to clip the weights |
switch_eval |
bool |
False |
Whether the model should be set to eval mode when calculating loss |
gen_first |
bool |
False |
No Content |
show_img |
bool |
True |
No Content |
cbs |
(Callback, None, list) |
None |
No Content |
metrics |
(None, list, callable) |
None |
No Content |
loss_func |
NoneType |
None |
No Content |
opt_func |
function |
<function Adam> |
No Content |
lr |
float |
0.001 |
No Content |
splitter |
function |
<function trainable_params> |
No Content |
path |
NoneType |
None |
No Content |
model_dir |
str |
models |
No Content |
wd |
NoneType |
None |
No Content |
wd_bn_bias |
bool |
False |
No Content |
train_bn |
bool |
True |
No Content |
moms |
tuple |
(0.95, 0.85, 0.95) |
No Content |
from fastai.callback.all import *

# Build a small DCGAN-style generator/critic pair for 64x64 RGB images.
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic(64, n_channels=3, n_extra_layers=1,
                      act_cls=partial(nn.LeakyReLU, negative_slope=0.2))

learn = GANLearner.wgan(dls, generator, critic, opt_func=RMSProp)

# Only record metrics on the training set.
learn.recorder.train_metrics, learn.recorder.valid_metrics = True, False

learn.fit(n_epoch=1, lr=2e-4, wd=0.)
# Recorded output of the cell above:
# /home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
#   warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
# /home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
#   warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
# /home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
#   warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
# | epoch | train_loss | gen_loss | crit_loss | time |
# |---|---|---|---|---|
# | 0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
# /home/tmabraham/anaconda3/envs/fastai/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
#   warn("Your generator is empty.")

learn.show_results(max_n=9, ds_idx=0)
导出 -
# Export the notebooks to library modules.
from nbdev import nbdev_export

nbdev_export()
# Output (first line): Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted index_original.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.