弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

时间:2024-03-31 13:22:56

ElasticNet 是一种使用L1和L2先验作为正则化矩阵的线性回归模型.这种组合用于只有很少的权重非零的稀疏模型,比如:class:Lasso, 但是又能保持:class:Ridge 的正则化属性.我们可以使用 l1_ratio 参数来调节L1和L2的凸组合(一类特殊的线性组合)。

当多个特征和另一个特征相关的时候弹性网络非常有用。Lasso 倾向于随机选择其中一个,而弹性网络更倾向于选择两个.
在实践中,Lasso 和 Ridge 之间权衡的一个优势是它允许在循环过程(Under rotate)中继承 Ridge 的稳定性.
弹性网络的目标函数是最小化:

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

ElasticNetCV 可以通过交叉验证来用来设置参数 alpha (弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso) 和 l1_ratio (弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso)

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

[python] view plain copy
 弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso
  1. print(__doc__)  
  2.   
  3. import numpy as np  
  4. import matplotlib.pyplot as plt  
  5.   
  6. from sklearn.linear_model import lasso_path, enet_path  
  7. from sklearn import datasets  
  8.   
  9. diabetes = datasets.load_diabetes()  
  10. X = diabetes.data  
  11. y = diabetes.target  
  12.   
  13. X /= X.std(axis=0)  # Standardize data (easier to set the l1_ratio parameter)  
  14.   
  15. # Compute paths  
  16.   
  17. eps = 5e-3  # the smaller it is the longer is the path  
  18.   
  19. print("Computing regularization path using the lasso...")  
  20. alphas_lasso, coefs_lasso, _ = lasso_path(X, y, eps, fit_intercept=False)  
  21.   
  22. print("Computing regularization path using the positive lasso...")  
  23. alphas_positive_lasso, coefs_positive_lasso, _ = lasso_path(  
  24.     X, y, eps, positive=True, fit_intercept=False)  
  25. print("Computing regularization path using the elastic net...")  
  26. alphas_enet, coefs_enet, _ = enet_path(  
  27.     X, y, eps=eps, l1_ratio=0.8, fit_intercept=False)  
  28.   
  29. print("Computing regularization path using the positve elastic net...")  
  30. alphas_positive_enet, coefs_positive_enet, _ = enet_path(  
  31.     X, y, eps=eps, l1_ratio=0.8, positive=True, fit_intercept=False)  
  32.   
  33. # Display results  
  34.   
  35. plt.figure(1)  
  36. ax = plt.gca()  
  37. ax.set_color_cycle(2 * ['b''r''g''c''k'])  
  38. l1 = plt.plot(-np.log10(alphas_lasso), coefs_lasso.T)  
  39. l2 = plt.plot(-np.log10(alphas_enet), coefs_enet.T, linestyle='--')  
  40.   
  41. plt.xlabel('-Log(alpha)')  
  42. plt.ylabel('coefficients')  
  43. plt.title('Lasso and Elastic-Net Paths')  
  44. plt.legend((l1[-1], l2[-1]), ('Lasso''Elastic-Net'), loc='lower left')  
  45. plt.axis('tight')  
  46.   
  47.   
  48. plt.figure(2)  
  49. ax = plt.gca()  
  50. ax.set_color_cycle(2 * ['b''r''g''c''k'])  
  51. l1 = plt.plot(-np.log10(alphas_lasso), coefs_lasso.T)  
  52. l2 = plt.plot(-np.log10(alphas_positive_lasso), coefs_positive_lasso.T,  
  53.               linestyle='--')  
  54.   
  55. plt.xlabel('-Log(alpha)')  
  56. plt.ylabel('coefficients')  
  57. plt.title('Lasso and positive Lasso')  
  58. plt.legend((l1[-1], l2[-1]), ('Lasso''positive Lasso'), loc='lower left')  
  59. plt.axis('tight')  
  60.   
  61.   
  62. plt.figure(3)  
  63. ax = plt.gca()  
  64. ax.set_color_cycle(2 * ['b''r''g''c''k'])  
  65. l1 = plt.plot(-np.log10(alphas_enet), coefs_enet.T)  
  66. l2 = plt.plot(-np.log10(alphas_positive_enet), coefs_positive_enet.T,  
  67.               linestyle='--')  
  68.   
  69. plt.xlabel('-Log(alpha)')  
  70. plt.ylabel('coefficients')  
  71. plt.title('Elastic-Net and positive Elastic-Net')  
  72. plt.legend((l1[-1], l2[-1]), ('Elastic-Net''positive Elastic-Net'),  
  73.            loc='lower left')  
  74. plt.axis('tight')  
  75. plt.show()  
弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso



MultiTaskLasso 是一种估计多元回归系数的线性模型, y 是一个2D数组,形式为(n_samples,n_tasks). 其限制条件是和其他回归问题一样,是选择的特征,同样称为 tasks.

接下来的图示比较了通过使用一个简单的Lasso或者MultiTaskLasso得到的W中非零的位置。 Lasso 估计量分散着非零值而MultiTaskLasso所有的列全部是非零的。

数学表达上,它包含了一个使用 弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso 弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso 先验作为正则化因子。其目标函数是最小化:

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso


这里弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

MultiTaskLasso 类的实现使用了坐标下降算法来拟合系数。

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso

拟合的时间序列模型

[python] view plain copy
 弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso弹性网络( Elastic Net) 多任务 Lasso回归 MultiTaskLasso
  1. import matplotlib.pyplot as plt  
  2. import numpy as np  
  3.   
  4. from sklearn.linear_model import MultiTaskLasso, Lasso  
  5.   
  6. rng = np.random.RandomState(42)  
  7.   
  8. # Generate some 2D coefficients with sine waves with random frequency and phase  
  9. n_samples, n_features, n_tasks = 1003040  
  10. n_relevant_features = 5  
  11. coef = np.zeros((n_tasks, n_features))  
  12. times = np.linspace(02 * np.pi, n_tasks)  
  13. for k in range(n_relevant_features):  
  14.     coef[:, k] = np.sin((1. + rng.randn(1)) * times + 3 * rng.randn(1))  
  15.   
  16. X = rng.randn(n_samples, n_features)  
  17. Y = np.dot(X, coef.T) + rng.randn(n_samples, n_tasks)  
  18.   
  19. coef_lasso_ = np.array([Lasso(alpha=0.5).fit(X, y).coef_ for y in Y.T])  
  20. coef_multi_task_lasso_ = MultiTaskLasso(alpha=1.).fit(X, Y).coef_  
  21.   
  22. ###############################################################################  
  23. # Plot support and time series  
  24. fig = plt.figure(figsize=(85))  
  25. plt.subplot(121)  
  26. plt.spy(coef_lasso_)  
  27. plt.xlabel('Feature')  
  28. plt.ylabel('Time (or Task)')  
  29. plt.text(105'Lasso')  
  30. plt.subplot(122)  
  31. plt.spy(coef_multi_task_lasso_)  
  32. plt.xlabel('Feature')  
  33. plt.ylabel('Time (or Task)')  
  34. plt.text(105'MultiTaskLasso')  
  35. fig.suptitle('Coefficient non-zero location')  
  36.   
  37. feature_to_plot = 0  
  38. plt.figure()  
  39. plt.plot(coef[:, feature_to_plot], 'k', label='Ground truth')  
  40. plt.plot(coef_lasso_[:, feature_to_plot], 'g', label='Lasso')  
  41. plt.plot(coef_multi_task_lasso_[:, feature_to_plot],  
  42.          'r', label='MultiTaskLasso')  
  43. plt.legend(loc='upper center')  
  44. plt.axis('tight')  
  45. plt.ylim([-1.11.1])  
  46. plt.show()  

0