Source for org.jfree.data.time.MovingAverage

   1: /* ===========================================================
   2:  * JFreeChart : a free chart library for the Java(tm) platform
   3:  * ===========================================================
   4:  *
   5:  * (C) Copyright 2000-2005, by Object Refinery Limited and Contributors.
   6:  *
   7:  * Project Info:  http://www.jfree.org/jfreechart/index.html
   8:  *
   9:  * This library is free software; you can redistribute it and/or modify it 
  10:  * under the terms of the GNU Lesser General Public License as published by 
  11:  * the Free Software Foundation; either version 2.1 of the License, or 
  12:  * (at your option) any later version.
  13:  *
  14:  * This library is distributed in the hope that it will be useful, but 
  15:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 
  16:  * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public 
  17:  * License for more details.
  18:  *
  19:  * You should have received a copy of the GNU Lesser General Public
  20:  * License along with this library; if not, write to the Free Software
  21:  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, 
  22:  * USA.  
  23:  *
  24:  * [Java is a trademark or registered trademark of Sun Microsystems, Inc. 
  25:  * in the United States and other countries.]
  26:  *
  27:  * ------------------
  28:  * MovingAverage.java
  29:  * ------------------
  30:  * (C) Copyright 2003-2005, by Object Refinery Limited.
  31:  *
  32:  * Original Author:  David Gilbert (for Object Refinery Limited);
  33:  * Contributor(s):   Benoit Xhenseval;
  34:  *
  35:  * $Id: MovingAverage.java,v 1.5.2.1 2005/10/25 21:35:24 mungady Exp $
  36:  *
  37:  * Changes
  38:  * -------
  39:  * 28-Jan-2003 : Version 1 (DG);
  40:  * 10-Mar-2003 : Added createPointMovingAverage() method contributed by Benoit 
  41:  *               Xhenseval (DG);
  42:  * 01-Aug-2003 : Added new method for TimeSeriesCollection, and fixed bug in 
  43:  *               XYDataset method (DG);
  44:  * 15-Jul-2004 : Switched getX() with getXValue() and getY() with 
  45:  *               getYValue() (DG);
  46:  * 11-Jan-2005 : Removed deprecated code in preparation for the 1.0.0 
  47:  *               release (DG);
  48:  *
  49:  */
  50: 
  51: package org.jfree.data.time;
  52: 
  53: import org.jfree.data.xy.XYDataset;
  54: import org.jfree.data.xy.XYSeries;
  55: import org.jfree.data.xy.XYSeriesCollection;
  56: 
  57: /**
  58:  * A utility class for calculating moving averages of time series data.
  59:  */
  60: public class MovingAverage {
  61: 
  62:     /**
  63:      * Creates a new {@link TimeSeriesCollection} containing a moving average 
  64:      * series for each series in the source collection.
  65:      * 
  66:      * @param source  the source collection.
  67:      * @param suffix  the suffix added to each source series name to create the
  68:      *                corresponding moving average series name.
  69:      * @param periodCount  the number of periods in the moving average 
  70:      *                     calculation.
  71:      * @param skip  the number of initial periods to skip.
  72:      * 
  73:      * @return A collection of moving average time series.
  74:      */
  75:     public static TimeSeriesCollection createMovingAverage(
  76:         TimeSeriesCollection source, String suffix, int periodCount,
  77:         int skip) {
  78:     
  79:         // check arguments
  80:         if (source == null) {
  81:             throw new IllegalArgumentException(
  82:                 "MovingAverage.createMovingAverage() : null source."
  83:             );
  84:         }
  85: 
  86:         if (periodCount < 1) {
  87:             throw new IllegalArgumentException(
  88:                 "periodCount must be greater than or equal to 1."
  89:             );
  90:         }
  91: 
  92:         TimeSeriesCollection result = new TimeSeriesCollection();
  93:         
  94:         for (int i = 0; i < source.getSeriesCount(); i++) {
  95:             TimeSeries sourceSeries = source.getSeries(i);
  96:             TimeSeries maSeries = createMovingAverage(
  97:                 sourceSeries, sourceSeries.getKey() + suffix, periodCount, skip
  98:             );
  99:             result.addSeries(maSeries);       
 100:         }
 101:         
 102:         return result;
 103:         
 104:     }
 105:     
 106:     /**
 107:      * Creates a new {@link TimeSeries} containing moving average values for 
 108:      * the given series.  If the series is empty (contains zero items), the 
 109:      * result is an empty series.
 110:      *
 111:      * @param source  the source series.
 112:      * @param name  the name of the new series.
 113:      * @param periodCount  the number of periods used in the average 
 114:      *                     calculation.
 115:      * @param skip  the number of initial periods to skip.
 116:      *
 117:      * @return The moving average series.
 118:      */
 119:     public static TimeSeries createMovingAverage(TimeSeries source,
 120:                                                  String name,
 121:                                                  int periodCount,
 122:                                                  int skip) {
 123: 
 124:         // check arguments
 125:         if (source == null) {
 126:             throw new IllegalArgumentException("Null source.");
 127:         }
 128: 
 129:         if (periodCount < 1) {
 130:             throw new IllegalArgumentException(
 131:                 "periodCount must be greater than or equal to 1."
 132:             );
 133: 
 134:         }
 135: 
 136:         TimeSeries result = new TimeSeries(name, source.getTimePeriodClass());
 137: 
 138:         if (source.getItemCount() > 0) {
 139: 
 140:             // if the initial averaging period is to be excluded, then 
 141:             // calculate the index of the
 142:             // first data item to have an average calculated...
 143:             long firstSerial 
 144:                 = source.getDataItem(0).getPeriod().getSerialIndex() + skip;
 145: 
 146:             for (int i = source.getItemCount() - 1; i >= 0; i--) {
 147: 
 148:                 // get the current data item...
 149:                 TimeSeriesDataItem current = source.getDataItem(i);
 150:                 RegularTimePeriod period = current.getPeriod();
 151:                 long serial = period.getSerialIndex();
 152: 
 153:                 if (serial >= firstSerial) {
 154:                     // work out the average for the earlier values...
 155:                     int n = 0;
 156:                     double sum = 0.0;
 157:                     long serialLimit = period.getSerialIndex() - periodCount;
 158:                     int offset = 0;
 159:                     boolean finished = false;
 160: 
 161:                     while ((offset < periodCount) && (!finished)) {
 162:                         if ((i - offset) >= 0) {
 163:                             TimeSeriesDataItem item 
 164:                                 = source.getDataItem(i - offset);
 165:                             RegularTimePeriod p = item.getPeriod();
 166:                             Number v = item.getValue();
 167:                             long currentIndex = p.getSerialIndex();
 168:                             if (currentIndex > serialLimit) {
 169:                                 if (v != null) {
 170:                                     sum = sum + v.doubleValue();
 171:                                     n = n + 1;
 172:                                 }
 173:                             }
 174:                             else {
 175:                                 finished = true;
 176:                             }
 177:                         }
 178:                         offset = offset + 1;
 179:                     }
 180:                     if (n > 0) {
 181:                         result.add(period, sum / n);
 182:                     }
 183:                     else {
 184:                         result.add(period, null);
 185:                     }
 186:                 }
 187: 
 188:             }
 189:         }
 190: 
 191:         return result;
 192: 
 193:     }
 194: 
 195:     /**
 196:      * Creates a new {@link TimeSeries} containing moving average values for 
 197:      * the given series, calculated by number of points (irrespective of the 
 198:      * 'age' of those points).  If the series is empty (contains zero items), 
 199:      * the result is an empty series.
 200:      * <p>
 201:      * Developed by Benoit Xhenseval (www.ObjectLab.co.uk).
 202:      *
 203:      * @param source  the source series.
 204:      * @param name  the name of the new series.
 205:      * @param pointCount  the number of POINTS used in the average calculation 
 206:      *                    (not periods!)
 207:      *
 208:      * @return The moving average series.
 209:      */
 210:     public static TimeSeries createPointMovingAverage(TimeSeries source,
 211:                                                       String name, 
 212:                                                       int pointCount) {
 213: 
 214:         // check arguments
 215:         if (source == null) {
 216:             throw new IllegalArgumentException("Null 'source'.");
 217:         }
 218: 
 219:         if (pointCount < 2) {
 220:             throw new IllegalArgumentException(
 221:                 "periodCount must be greater than or equal to 2."
 222:             );
 223:         }
 224: 
 225:         TimeSeries result = new TimeSeries(name, source.getTimePeriodClass());
 226:         double rollingSumForPeriod = 0.0;
 227:         for (int i = 0; i < source.getItemCount(); i++) {
 228:             // get the current data item...
 229:             TimeSeriesDataItem current = source.getDataItem(i);
 230:             RegularTimePeriod period = current.getPeriod();
 231:             rollingSumForPeriod += current.getValue().doubleValue();
 232: 
 233:             if (i > pointCount - 1) {
 234:                 // remove the point i-periodCount out of the rolling sum.
 235:                 TimeSeriesDataItem startOfMovingAvg 
 236:                     = source.getDataItem(i - pointCount);
 237:                 rollingSumForPeriod 
 238:                     -= startOfMovingAvg.getValue().doubleValue();
 239:                 result.add(period, rollingSumForPeriod / pointCount);
 240:             }
 241:             else if (i == pointCount - 1) {
 242:                 result.add(period, rollingSumForPeriod / pointCount);
 243:             }
 244:         }
 245:         return result;
 246:     }
 247: 
 248:     /**
 249:      * Creates a new {@link XYDataset} containing the moving averages of each 
 250:      * series in the <code>source</code> dataset.
 251:      *
 252:      * @param source  the source dataset.
 253:      * @param suffix  the string to append to source series names to create 
 254:      *                target series names.
 255:      * @param period  the averaging period.
 256:      * @param skip  the length of the initial skip period.
 257:      *
 258:      * @return The dataset.
 259:      */
 260:     public static XYDataset createMovingAverage(XYDataset source, String suffix,
 261:                                                 long period, final long skip) {
 262: 
 263:         return createMovingAverage(
 264:             source, suffix, (double) period, (double) skip
 265:         );
 266:         
 267:     }
 268: 
 269: 
 270:     /**
 271:      * Creates a new {@link XYDataset} containing the moving averages of each 
 272:      * series in the <code>source</code> dataset.
 273:      *
 274:      * @param source  the source dataset.
 275:      * @param suffix  the string to append to source series names to create 
 276:      *                target series names.
 277:      * @param period  the averaging period.
 278:      * @param skip  the length of the initial skip period.
 279:      *
 280:      * @return The dataset.
 281:      */
 282:     public static XYDataset createMovingAverage(XYDataset source, String suffix,
 283:                                                 double period, double skip) {
 284: 
 285:         // check arguments
 286:         if (source == null) {
 287:             throw new IllegalArgumentException("Null source (XYDataset).");
 288:         }
 289:         
 290:         XYSeriesCollection result = new XYSeriesCollection();
 291: 
 292:         for (int i = 0; i < source.getSeriesCount(); i++) {
 293:             XYSeries s = createMovingAverage(
 294:                 source, i, source.getSeriesKey(i) + suffix, period, skip
 295:             );
 296:             result.addSeries(s);
 297:         }
 298: 
 299:         return result;
 300: 
 301:     }
 302: 
 303:     /**
 304:      * Creates a new {@link XYSeries} containing the moving averages of one 
 305:      * series in the <code>source</code> dataset.
 306:      *
 307:      * @param source  the source dataset.
 308:      * @param series  the series index (zero based).
 309:      * @param name  the name for the new series.
 310:      * @param period  the averaging period.
 311:      * @param skip  the length of the initial skip period.
 312:      *
 313:      * @return The dataset.
 314:      */
 315:     public static XYSeries createMovingAverage(XYDataset source, 
 316:                                                int series, String name,
 317:                                                double period, double skip) {
 318: 
 319:                                                
 320:         // check arguments
 321:         if (source == null) {
 322:             throw new IllegalArgumentException("Null source (XYDataset).");
 323:         }
 324: 
 325:         if (period < Double.MIN_VALUE) {
 326:             throw new IllegalArgumentException("period must be positive.");
 327: 
 328:         }
 329: 
 330:         if (skip < 0.0) {
 331:             throw new IllegalArgumentException("skip must be >= 0.0.");
 332: 
 333:         }
 334: 
 335:         XYSeries result = new XYSeries(name);
 336: 
 337:         if (source.getItemCount(series) > 0) {
 338: 
 339:             // if the initial averaging period is to be excluded, then 
 340:             // calculate the lowest x-value to have an average calculated...
 341:             double first = source.getXValue(series, 0) + skip;
 342: 
 343:             for (int i = source.getItemCount(series) - 1; i >= 0; i--) {
 344: 
 345:                 // get the current data item...
 346:                 double x = source.getXValue(series, i);
 347: 
 348:                 if (x >= first) {
 349:                     // work out the average for the earlier values...
 350:                     int n = 0;
 351:                     double sum = 0.0;
 352:                     double limit = x - period;
 353:                     int offset = 0;
 354:                     boolean finished = false;
 355: 
 356:                     while (!finished) {
 357:                         if ((i - offset) >= 0) {
 358:                             double xx = source.getXValue(series, i - offset);
 359:                             Number yy = source.getY(series, i - offset);
 360:                             if (xx > limit) {
 361:                                 if (yy != null) {
 362:                                     sum = sum + yy.doubleValue();
 363:                                     n = n + 1;
 364:                                 }
 365:                             }
 366:                             else {
 367:                                 finished = true;
 368:                             }
 369:                         }
 370:                         else {
 371:                             finished = true;
 372:                         }
 373:                         offset = offset + 1;
 374:                     }
 375:                     if (n > 0) {
 376:                         result.add(x, sum / n);
 377:                     }
 378:                     else {
 379:                         result.add(x, null);
 380:                     }
 381:                 }
 382: 
 383:             }
 384:         }
 385: 
 386:         return result;
 387: 
 388:     }
 389: 
 390: }