class ResNext(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layers: List[int],
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        use_pixel_shortcut: bool = False,
        use_s1_block: bool = False,
        num_s1_channels: Optional[int] = None,
        use_max_res: bool = False,
        block: Type[Bottleneck] = Bottleneck
    ) -> None:
        super(ResNext, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.use_pixel_shortcut = use_pixel_shortcut
        self.use_s1_block = use_s1_block
        self.num_s1_channels = num_s1_channels
        self.use_max_res = use_max_res

        if use_pixel_shortcut:
            self.pixel_shortcut = nn.Sequential(
                conv1x1(in_channels, 128),
                nn.ReLU(inplace=True),
                conv1x1(128, 256)
            )

        if use_s1_block:
            assert num_s1_channels
            if in_channels - num_s1_channels > 0:
                # we have S2 channels
                self.entry_block = nn.Sequential(
                    conv1x1(in_channels - num_s1_channels, self.inplanes // 2),
                    norm_layer(self.inplanes // 2),
                    nn.ReLU(inplace=True)
                )
                self.s1_block = nn.Sequential(
                    nn.Conv2d(num_s1_channels, self.inplanes // 2, kernel_size=5, padding=2),
                    norm_layer(self.inplanes // 2),
                    nn.ReLU(inplace=True)
                )
            else:
                # no S2 channels
                self.s1_block = nn.Sequential(
                    nn.Conv2d(num_s1_channels, self.inplanes, kernel_size=5, padding=2),
                    norm_layer(self.inplanes),
                    nn.ReLU(inplace=True)
                )
        else:
            self.entry_block = nn.Sequential(
                conv1x1(in_channels, self.inplanes),
                norm_layer(self.inplanes),
                nn.ReLU(inplace=True)
            )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilate=replace_stride_with_dilation[2])

        in_channels_heads = (512 * block.expansion + 256) if use_pixel_shortcut else 512 * block.expansion
        self.mu_head = nn.Sequential(
            conv1x1(in_channels_heads, 512),
            nn.ReLU(inplace=True),
            conv1x1(512, out_channels)
        )
        self.log_phi_squared_head = nn.Sequential(
            conv1x1(in_channels_heads, 512),
            nn.ReLU(inplace=True),
            conv1x1(512, out_channels)
        )

        # task-dependent homoscedastic log variances (currently unused, but kept so that
        # checkpoint loading does not break)
        self.log_eta_squared = nn.Parameter(torch.zeros(out_channels))

    def _make_layer(self, block: Type[Bottleneck], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        if self.use_pixel_shortcut:
            pixel_shortcut = self.pixel_shortcut(x)

        if self.use_s1_block:
            x_s2, x_s1 = x[:, :-self.num_s1_channels], x[:, -self.num_s1_channels:]
            if hasattr(self, 'entry_block'):
                x_s2 = self.entry_block(x_s2)
            x_s1 = self.s1_block(x_s1)
            x = torch.cat([x_s2, x_s1], dim=1)
        else:
            x = self.entry_block(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.use_pixel_shortcut:
            # concatenate in channel dim
            x = torch.cat([x, pixel_shortcut], dim=1)

        mean = self.mu_head(x)
        # channel 0 predicts the mean height; with use_max_res, channel 1 predicts (max - mean)
        # and is converted back to the max height here
        if self.use_max_res:
            mean[:, 1, :, :] = mean[:, 0, :, :] + mean[:, 1, :, :]
        log_variance = self.log_phi_squared_head(x)
        return mean, log_variance
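# Usage sketch (assumption, not part of the original module): how this backbone might be
# instantiated for a two-channel (mean height, max height) regression target. The layer depths,
# group width and input channel count below are purely illustrative; `Bottleneck` and `conv1x1`
# are assumed to be the torchvision-style ResNet building blocks imported elsewhere in this file.
def _example_build_resnext():  # hypothetical helper, for illustration only
    model = ResNext(
        in_channels=12,          # assumed number of input bands
        out_channels=2,          # mean and max height
        layers=[3, 4, 6, 3],     # ResNeXt-50-style depth
        groups=32,
        width_per_group=4,
        use_pixel_shortcut=True,
    )
    # all layers use stride 1, so the spatial resolution of the input patch is preserved;
    # both heads return tensors of shape (batch, out_channels, H, W)
    mu, log_var = model(torch.randn(2, 12, 15, 15))
    return mu.shape, log_var.shape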
@run.train_step
@run.validate_step
def step(batch, model, epoch):
    x, y = batch

    # these ground truth locations are invalid and should not be considered for the loss calculation
    mask = torch.isnan(y)  # & (y > 59)

    # evidential learning
    if model_type.endswith('edl'):
        gamma, mu, alpha, beta = model(x)
        # loss function
        nll = nanmean(NIG_NLL(y, gamma, mu, alpha, beta), mask, dim=(0, 2, 3))
        # reg = nanmean(NIG_Reg(y, gamma, mu, alpha), mask, dim=(0, 2, 3))
        alvar = beta / (alpha - 1)
        epvar = alvar / mu
        log_var = torch.log(alvar + epvar)
        loss = nll  # + lam * (reg - epsilon)
    # Bayesian learning
    else:
        mu, log_var = model(x)
        log_var = limit(log_var)

        # ensure positive heights
        if run['train_cfg.positive_mean']:
            # assert not run['data_cfg.normalize_labels']
            if run['data_cfg.normalize_labels']:
                # should be in [-1, 1]
                mu = torch.tanh(limit(mu))
            else:
                # mean and max height should be positive
                mu = torch.exp(limit(mu))

        # loss function
        if train_type == 'mse':
            nll = nanmean(mseloss(mu, y), mask, dim=(0, 2, 3))
        elif train_type == 'mae':
            nll = nanmean(maeloss(mu, y), mask, dim=(0, 2, 3))
        elif train_type == 'nll':
            nll = nanmean(negative_log_likelihood(mu, log_var, y), mask, dim=(0, 2, 3))
        elif train_type == 'sequential':
            if epoch < sequential_epoch:
                nll = nanmean(mseloss(mu, y), mask, dim=(0, 2, 3))
            else:
                nll = nanmean(negative_log_likelihood(mu, log_var, y), mask, dim=(0, 2, 3))
        else:
            raise ValueError(f'Invalid train_type: {train_type}')
        loss = nll

    # evaluate error
    if run['data_cfg.normalize_labels']:
        # need to denormalize
        error = (label_unnormalize(mu) - label_unnormalize(y)).detach()
    else:
        error = (mu - y).detach()
    # error[error.isnan()] = 0.

    # optionally also mask NaNs in the predictions:
    # mask_mu = torch.isnan(mu)
    # mask = mask.logical_or(mask_mu)
    mae = nanmean(error.abs(), mask, dim=(0, 2, 3))
    mse = nanmean(error ** 2, mask, dim=(0, 2, 3))
    me = nanmean(error, mask, dim=(0, 2, 3))
    mrmse = nanmean(torch.sqrt(error ** 2), mask, dim=(0, 2, 3))

    # variance evaluation: mean(abs(|error| - predicted std))
    mvareval = nanmean((torch.sqrt(error ** 2) - torch.sqrt(torch.exp(log_var))).abs(),
                       mask, dim=(0, 2, 3))
    # log_phi_squared_mean = nanmean(log_phi_squared, mask, dim=(0, 2, 3)).detach()
    log_var_mean = nanmean(log_var, mask, dim=(0, 2, 3)).detach()

    if run['log_cfg.log_plot']:
        data_y_yhat_mean = []
        data_y_error_mean = []
        data_y_logvar_mean = []
        data_y_yhat_max = []
        data_y_error_max = []
        data_y_logvar_max = []
        mask = mask.cpu().numpy()  # shape: B x 2 x 15 x 15
        for i in range(len(y)):
            # check the center point of the patch; skip samples without a valid label there
            if mask[i, 0, 8, 8]:
                continue
            data_y_yhat_mean.append([y[i, 0, 8, 8], mu[i, 0, 8, 8]])
            data_y_error_mean.append([y[i, 0, 8, 8], error[i, 0, 8, 8]])
            data_y_logvar_mean.append([y[i, 0, 8, 8], log_var[i, 0, 8, 8]])
            if mask[i, 1, 8, 8]:
                continue
            data_y_yhat_max.append([y[i, 1, 8, 8], mu[i, 1, 8, 8]])
            data_y_error_max.append([y[i, 1, 8, 8], error[i, 1, 8, 8]])
            data_y_logvar_max.append([y[i, 1, 8, 8], log_var[i, 1, 8, 8]])

        table_y_yhat_mean = wandb.Table(data=data_y_yhat_mean, columns=["target", "prediction"])
        table_y_error_mean = wandb.Table(data=data_y_error_mean, columns=["target", "error"])
        table_y_logvar_mean = wandb.Table(data=data_y_logvar_mean, columns=["target", "log_var"])
        table_y_yhat_max = wandb.Table(data=data_y_yhat_max, columns=["target", "prediction"])
        table_y_error_max = wandb.Table(data=data_y_error_max, columns=["target", "error"])
        table_y_logvar_max = wandb.Table(data=data_y_logvar_max, columns=["target", "log_var"])

        return {
            'loss': nll.mean(),
            **{'loss_' + m: nll[i] for i, m in enumerate(labels_names)},
            **{'mse_' + m: mse[i] for i, m in enumerate(labels_names)},
            **{'mae_' + m: mae[i] for i, m in enumerate(labels_names)},
            **{'me_' + m: me[i] for i, m in enumerate(labels_names)},
            # **{'log_phi_sq_' + m: log_phi_squared_mean[i] for i, m in enumerate(labels_names)},
            **{'log_eta_sq_' + m: model.log_eta_squared[i] for i, m in enumerate(labels_names)},
            **{'log_var_' + m: log_var_mean[i] for i, m in enumerate(labels_names)},
            "chart_mean y vs y_pred": wandb.plot.scatter(table_y_yhat_mean, "y", "y_pred"),
            "chart_mean y vs error": wandb.plot.scatter(table_y_error_mean, "y", "error"),
            "chart_mean y vs log_var": wandb.plot.scatter(table_y_logvar_mean, "y", "log_var"),
            "chart_max y vs y_pred": wandb.plot.scatter(table_y_yhat_max, "y", "y_pred"),
            "chart_max y vs error": wandb.plot.scatter(table_y_error_max, "y", "error"),
            "chart_max y vs log_var": wandb.plot.scatter(table_y_logvar_max, "y", "log_var"),
        }

    return {
        'loss': loss.mean(),
        **{'loss_' + m: nll[i] for i, m in enumerate(labels_names)},
        **{'mse_' + m: mse[i] for i, m in enumerate(labels_names)},
        **{'mae_' + m: mae[i] for i, m in enumerate(labels_names)},
        **{'me_' + m: me[i] for i, m in enumerate(labels_names)},
        **{'mrmse_' + m: mrmse[i] for i, m in enumerate(labels_names)},
        **{'mrmse_rmvar_' + m: mvareval[i] for i, m in enumerate(labels_names)},
        # **{'log_phi_sq_' + m: log_phi_squared_mean[i] for i, m in enumerate(labels_names)},
        **{'log_eta_sq_' + m: model.log_eta_squared[i] for i, m in enumerate(labels_names)},
        **{'log_var_' + m: log_var_mean[i] for i, m in enumerate(labels_names)},
    }
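# Hedged sketches (assumptions, not the repo's actual helpers): `nanmean` above is used as a masked
# mean that ignores positions where `mask` is True and reduces over the given dims, giving one value
# per channel, and `negative_log_likelihood` is used as a per-pixel heteroscedastic Gaussian NLL on
# (mu, log_var). Minimal implementations consistent with that usage could look like this; the real
# helpers defined elsewhere in the repo may differ.
def _nanmean_sketch(values, mask, dim):
    # zero out masked (invalid) positions so they do not contribute to the sum
    values = torch.where(mask, torch.zeros_like(values), values)
    # divide by the number of valid positions per remaining (channel) dimension
    valid = (~mask).sum(dim=dim).clamp(min=1)
    return values.sum(dim=dim) / valid


def _negative_log_likelihood_sketch(mu, log_var, y):
    # Gaussian NLL up to the constant 0.5 * log(2 * pi)
    return 0.5 * (log_var + (y - mu) ** 2 / torch.exp(log_var))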
@run.configure_optimizers
def configure_optimizers(model):
    optim = torch.optim.Adam(model.parameters(), lr=run['train_cfg.lr'],
                             weight_decay=run['train_cfg.weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, **run['train_cfg.scheduler'])
    return (optim, scheduler)
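# The MultiStepLR keyword arguments come straight from the run config. A plausible (assumed) value
# for run['train_cfg.scheduler'] would be something like:
#     {'milestones': [60, 90], 'gamma': 0.1}
# i.e. decay the Adam learning rate by 10x at epochs 60 and 90; the values actually used in the
# original experiments are not specified here.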