View Javadoc
1   /*
2    *
3    * The DbUnit Database Testing Framework
4    * Copyright (C)2002-2005, DbUnit.org
5    *
6    * This library is free software; you can redistribute it and/or
7    * modify it under the terms of the GNU Lesser General Public
8    * License as published by the Free Software Foundation; either
9    * version 2.1 of the License, or (at your option) any later version.
10   *
11   * This library is distributed in the hope that it will be useful,
12   * but WITHOUT ANY WARRANTY; without even the implied warranty of
13   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14   * Lesser General Public License for more details.
15   *
16   * You should have received a copy of the GNU Lesser General Public
17   * License along with this library; if not, write to the Free Software
18   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19   *
20   */
21  package org.dbunit.database;
22  
23  import java.sql.PreparedStatement;
24  import java.sql.ResultSet;
25  import java.sql.SQLException;
26  import java.util.ArrayList;
27  import java.util.HashMap;
28  import java.util.HashSet;
29  import java.util.Iterator;
30  import java.util.List;
31  import java.util.Map;
32  import java.util.Set;
33  import java.util.SortedSet;
34  import java.util.TreeSet;
35  
36  import org.apache.commons.collections.map.ListOrderedMap;
37  import org.dbunit.database.search.ForeignKeyRelationshipEdge;
38  import org.dbunit.dataset.DataSetException;
39  import org.dbunit.dataset.IDataSet;
40  import org.dbunit.dataset.ITable;
41  import org.dbunit.dataset.ITableIterator;
42  import org.dbunit.dataset.ITableMetaData;
43  import org.dbunit.dataset.filter.AbstractTableFilter;
44  import org.dbunit.util.SQLHelper;
45  import org.slf4j.Logger;
46  import org.slf4j.LoggerFactory;
47  
48  /**
49   * Filter a table given a map of the allowed rows based on primary key values.<br>
50   * It uses a depth-first algorithm (although not recursive - it might be refactored
51   * in the future) to define which rows are allowed, as well which rows are necessary
52   * (and hence allowed) because of dependencies with the allowed rows.<br>
53   * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
54   * TODO: test cases
55   * @author Felipe Leme (dbunit@felipeal.net)
56   * @author Last changed by: $Author$
57   * @version $Revision$ $Date$
58   * @since Sep 9, 2005
59   */
60  public class PrimaryKeyFilter extends AbstractTableFilter {
61  
62      private final IDatabaseConnection connection;
63  
64      private final PkTableMap allowedPKsPerTable;
65      private final PkTableMap allowedPKsInput;
66      private final PkTableMap pksToScanPerTable;
67  
68      private final boolean reverseScan;
69  
70      protected final Logger logger = LoggerFactory.getLogger(getClass());
71  
72      // cache the primary keys
73      private final Map pkColumnPerTable = new HashMap();
74  
75      private final Map fkEdgesPerTable = new HashMap();
76      private final Map fkReverseEdgesPerTable = new HashMap();
77  
78      // name of the tables, in reverse order of dependency
79      private final List tableNames = new ArrayList();
80  
81      /**
82       * Default constructor, it takes as input a map with desired rows in a final
83       * dataset; the filter will ensure that the rows necessary by these initial rows
84       * are also allowed (and so on...).
85       * @param connection database connection
86       * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
87       * of a table; value is a Set with allowed primary keys for that table)
88       * @param reverseDependency flag indicating if the rows that depend on a row should
89       * also be allowed by the filter
90       */
91      public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
92          this.connection = connection;    
93          this.allowedPKsPerTable = new PkTableMap();    
94          this.allowedPKsInput = allowedPKs;
95          this.reverseScan = reverseDependency;
96  
97          // we need a deep copy here
98          this.pksToScanPerTable = new PkTableMap(allowedPKs);
99      }
100 
101     public void nodeAdded(Object node) {
102         this.tableNames.add( node );
103         if ( this.logger.isDebugEnabled() ) {
104             this.logger.debug("nodeAdded: " + node );
105         }
106     }
107 
108     public void edgeAdded(ForeignKeyRelationshipEdge edge) {
109         if ( this.logger.isDebugEnabled() ) {
110             this.logger.debug("edgeAdded: " + edge );
111         }
112         // first add it to the "direct edges"
113         String from = (String) edge.getFrom();
114         Set edges = (Set) this.fkEdgesPerTable.get(from);
115         if ( edges == null ) {
116             edges = new HashSet();
117             this.fkEdgesPerTable.put( from, edges );
118         }
119         if ( ! edges.contains(edge) ) {
120             edges.add(edge);
121         }
122 
123         // then add it to the "reverse edges"
124         String to = (String) edge.getTo();
125         edges = (Set) this.fkReverseEdgesPerTable.get(to);
126         if ( edges == null ) {
127             edges = new HashSet();
128             this.fkReverseEdgesPerTable.put( to, edges );
129         }
130         if ( ! edges.contains(edge) ) {
131             edges.add(edge);
132         }
133 
134         // finally, update the PKs cache
135         updatePkCache(to, edge);
136 
137     }
138 
139     /**
140      * @see AbstractTableFilter
141      */
142     public boolean isValidName(String tableName) throws DataSetException {
143         //    boolean isValid = this.allowedIds.containsKey(tableName);
144         //    return isValid;
145         return true;
146     }
147 
148     public ITableIterator iterator(IDataSet dataSet, boolean reversed)
149     throws DataSetException {
150         if ( this.logger.isDebugEnabled() ) {
151             this.logger.debug("Filter.iterator()" );
152         }
153         try {
154             searchPKs(dataSet);
155         } catch (SQLException e) {
156             throw new DataSetException( e );
157         }
158         return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
159                 .iterator());
160     }
161 
162     private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
163         logger.debug("searchPKs(dataSet={}) - start", dataSet);
164 
165         int counter = 0;
166         while ( !this.pksToScanPerTable.isEmpty() ) {
167             counter ++;
168             if ( this.logger.isDebugEnabled() ) {
169                 this.logger.debug( "RUN # " + counter );
170             }
171 
172             for( int i=this.tableNames.size()-1; i>=0; i-- ) {
173                 String tableName = (String) this.tableNames.get(i);
174                 // TODO: support multi-column PKs
175                 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
176                 Set tmpSet = this.pksToScanPerTable.get( tableName );
177                 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
178                     Set pksToScan = new HashSet( tmpSet );
179                     if ( this.logger.isDebugEnabled() ) {
180                         this.logger.debug(  "before search: "+ tableName + "=>" + pksToScan );
181                     }
182                     scanPKs( tableName, pkColumn, pksToScan );
183                     scanReversePKs( tableName, pksToScan );
184                     allowPKs( tableName, pksToScan );
185                     removePKsToScan( tableName, pksToScan );
186                 } // if
187             } // for 
188             removeScannedTables();
189         } // while
190         if ( this.logger.isDebugEnabled() ) {
191             this.logger.debug( "Finished searchIds()" );
192         }
193     } 
194 
195     private void removeScannedTables() {
196         logger.debug("removeScannedTables() - start");
197         this.pksToScanPerTable.retainOnly(this.tableNames);
198     }
199 
200     private void allowPKs(String table, Set newAllowedPKs) {
201         logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
202 
203         // then, add the new IDs, but checking if it should be allowed to add them
204         Set forcedAllowedPKs = this.allowedPKsInput.get( table );
205         if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
206             allowedPKsPerTable.addAll(table, newAllowedPKs );
207         } else {
208             for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
209                 Object id = iterator.next();
210                 if( forcedAllowedPKs.contains(id) ) {
211                     allowedPKsPerTable.add(table, id);
212                 } 
213                 else 
214                 {
215                     if ( this.logger.isDebugEnabled() ) {
216                         this.logger.debug( "Discarding id " + id + " of table " + table + 
217                         " as it was not included in the input!" );
218                     }
219                 }
220             }
221         }
222     }
223 
224     private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
225         if (logger.isDebugEnabled())
226         {
227             logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
228                     new Object[]{ table, pkColumn, allowedIds });
229         }
230 
231         Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
232         if ( fkEdges == null || fkEdges.isEmpty() ) {
233             return;
234         }
235         // we need a temporary list as there is no warranty about the set order...
236         List fkTables = new ArrayList( fkEdges.size() );
237         StringBuffer colsBuffer = new StringBuffer();
238         for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
239             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
240             fkTables.add( edge.getTo() );
241             colsBuffer.append( edge.getFKColumn() );
242             if ( iterator.hasNext() ) {
243                 colsBuffer.append( ", " );
244             }
245         }
246         // NOTE: make sure the query below is compatible standard SQL
247         String sql = "SELECT " + colsBuffer + " FROM " + table + 
248         " WHERE " + pkColumn + " = ? ";
249         if ( this.logger.isDebugEnabled() ) {
250             this.logger.debug( "SQL: " + sql );
251         }
252 
253         scanPKs(table, sql, allowedIds, fkTables);
254     }
255 
256     private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
257     {
258         PreparedStatement pstmt = null;
259         ResultSet rs = null;
260         try {
261             pstmt = this.connection.getConnection().prepareStatement( sql );
262             for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
263                 Object pk = iterator.next(); // id being scanned
264                 if( this.logger.isDebugEnabled() ) {
265                     this.logger.debug("Executing sql for ? = " + pk );
266                 }
267                 pstmt.setObject( 1, pk );
268                 rs = pstmt.executeQuery();
269                 while( rs.next() ) {
270                     for( int i=0; i<fkTables.size(); i++ ) {
271                         String newTable = (String) fkTables.get(i);
272                         Object fk = rs.getObject(i+1);
273                         if( fk != null ) {
274                             if( this.logger.isDebugEnabled() ) {
275                                 this.logger.debug("New ID: " + newTable + "->" + fk);
276                             }
277                             addPKToScan( newTable, fk );
278                         } 
279                         else {
280                             this.logger.warn( "Found null FK for relationship  " + 
281                                     table + "=>" + newTable );
282                         }
283                     }
284                 }
285             }
286         } catch (SQLException e) {
287             logger.error("scanPKs()", e);
288         }
289         finally {
290             // new in the finally block. has been in the catch only before
291             SQLHelper.close( rs, pstmt );
292         }
293     }
294 
295     private void scanReversePKs(String table, Set pksToScan) throws SQLException {
296         logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
297 
298         if ( ! this.reverseScan ) {
299             return; 
300         }
301         Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
302         if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
303             return;
304         }
305         Iterator iterator = fkReverseEdges.iterator();
306         while ( iterator.hasNext() ) {
307             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
308             addReverseEdge( edge, pksToScan );
309         }
310     }
311 
312     private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
313         logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
314 
315         String fkTable = (String) edge.getFrom();
316         String fkColumn = edge.getFKColumn();
317         String pkColumn = getPKColumn( fkTable );
318         // NOTE: make sure the query below is compatible standard SQL
319         String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
320 
321         PreparedStatement pstmt = null;
322         ResultSet rs = null;
323         try {
324             if ( this.logger.isDebugEnabled() ) {
325                 this.logger.debug( "Preparing SQL query '" + sql + "'" );
326             }
327             pstmt = this.connection.getConnection().prepareStatement( sql );
328             for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
329                 Object pk = iterator.next();
330                 if ( this.logger.isDebugEnabled() ) {
331                     this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
332                 }
333                 pstmt.setObject( 1, pk );
334                 rs = pstmt.executeQuery();
335                 while( rs.next() ) {
336                     Object fk = rs.getObject(1);
337                     addPKToScan( fkTable, fk );
338                 }
339             } 
340         } finally {
341             SQLHelper.close( rs, pstmt );
342         }
343     }
344 
345     private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
346         logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
347 
348         Object pkTo = this.pkColumnPerTable.get(table);
349         if ( pkTo == null ) {
350             String pkColumn = edge.getPKColumn();
351             this.pkColumnPerTable.put( table, pkColumn );
352         }
353     }
354 
355     // TODO: support PKs with multiple values
356     private String getPKColumn( String table ) throws SQLException {
357         logger.debug("getPKColumn(table={}) - start", table);
358 
359         // Try to get the cached column
360         String pkColumn = (String) this.pkColumnPerTable.get( table );
361         if ( pkColumn == null ) {
362             // If the column has not been cached until now retrieve it from the database connection
363             pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
364             this.pkColumnPerTable.put( table, pkColumn );
365         }
366         return pkColumn;
367     }
368 
369 
370     private void removePKsToScan(String table, Set ids) {
371         logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
372 
373         Set pksToScan = this.pksToScanPerTable.get(table);
374         if ( pksToScan != null ) {
375             if ( pksToScan == ids ) {   
376                 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
377             } else {
378                 pksToScan.removeAll( ids );
379             }
380         }    
381     }
382 
383     private void addPKToScan(String table, Object pk) {
384         logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
385 
386         // first, check if it wasn't added yet
387         if(this.allowedPKsPerTable.contains(table, pk)) {
388             if ( this.logger.isDebugEnabled() ) {
389                 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
390             }
391             return;
392         }
393 
394         this.pksToScanPerTable.add(table, pk);
395     }
396 
397     public String toString() {
398         StringBuffer sb = new StringBuffer();
399         sb.append("tableNames=").append(tableNames);
400         sb.append(", allowedPKsInput=").append(allowedPKsInput);
401         sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
402         sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
403         sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
404         sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
405         sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
406         sb.append(", reverseScan=").append(reverseScan);
407         sb.append(", connection=").append(connection);
408         return sb.toString();
409     }
410 
411 
412     private class FilterIterator implements ITableIterator {
413 
414         private final ITableIterator _iterator;
415 
416         public FilterIterator(ITableIterator iterator) {
417 
418             _iterator = iterator;
419         }
420 
421         ////////////////////////////////////////////////////////////////////////////
422         // ITableIterator interface
423 
424         public boolean next() throws DataSetException {
425             if ( logger.isDebugEnabled() ) {
426                 logger.debug("Iterator.next()" );
427             }      
428             while (_iterator.next()) {
429                 if (accept(_iterator.getTableMetaData().getTableName())) {
430                     return true;
431                 }
432             }
433             return false;
434         }
435 
436         public ITableMetaData getTableMetaData() throws DataSetException {
437             if ( logger.isDebugEnabled() ) {
438                 logger.debug("Iterator.getTableMetaData()" );
439             }      
440             return _iterator.getTableMetaData();
441         }
442 
443         public ITable getTable() throws DataSetException {
444             if ( logger.isDebugEnabled() ) {
445                 logger.debug("Iterator.getTable()" );
446             }
447             ITable table = _iterator.getTable();
448             String tableName = table.getTableMetaData().getTableName();
449             Set allowedPKs = allowedPKsPerTable.get( tableName );
450             if ( allowedPKs != null ) {
451                 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
452             }
453             return table;
454         }
455     }
456 
457     /**
458      * Map that associates a table with a set of primary key objects.
459      * 
460      * @author gommma (gommma AT users.sourceforge.net)
461      * @author Last changed by: $Author$
462      * @version $Revision$ $Date$
463      * @since 2.3.0
464      */
465     public static class PkTableMap
466     {
467         private final ListOrderedMap pksPerTable;
468         private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
469 
470         public PkTableMap()
471         {
472             this.pksPerTable = new ListOrderedMap();
473         }
474 
475         /**
476          * Copy constructor
477          * @param allowedPKs
478          */
479         public PkTableMap(PkTableMap allowedPKs) {
480             this.pksPerTable = new ListOrderedMap(); 
481             Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
482             while ( iterator.hasNext() ) {
483                 Map.Entry entry = (Map.Entry) iterator.next();
484                 String table = (String)entry.getKey();
485                 SortedSet pkObjectSet = (SortedSet) entry.getValue();
486                 SortedSet newSet = new TreeSet( pkObjectSet );
487                 this.pksPerTable.put( table, newSet );
488             }
489         }
490 
491         public int size() {
492             return pksPerTable.size();
493         }
494 
495         public boolean isEmpty() {
496             return pksPerTable.isEmpty();
497         }
498 
499         public boolean contains(String table, Object pkObject) {
500             Set pksPerTable = this.get(table);
501             return (pksPerTable != null && pksPerTable.contains(pkObject));
502         }
503 
504         public void remove(String tableName) {
505             this.pksPerTable.remove(tableName);
506         }
507 
508         public void put(String table, SortedSet pkObjects) {
509             this.pksPerTable.put(table, pkObjects);
510         }
511 
512         public void add(String tableName, Object pkObject) {
513             Set pksPerTable = getCreateIfNeeded(tableName);
514             pksPerTable.add(pkObject);
515         }
516 
517         public void addAll(String tableName, Set pkObjectsToAdd) {
518             Set pksPerTable = this.getCreateIfNeeded(tableName);
519             pksPerTable.addAll(pkObjectsToAdd);
520         }
521 
522         public SortedSet get(String tableName) {
523             return (SortedSet) this.pksPerTable.get(tableName);
524         }
525 
526         private SortedSet getCreateIfNeeded(String tableName){
527             SortedSet pksPerTable = this.get(tableName);
528             // Lazily create the set if it did not exist yet
529             if( pksPerTable == null ) {
530                 pksPerTable = new TreeSet();
531                 this.pksPerTable.put(tableName, pksPerTable);
532             }
533             return pksPerTable;
534         }
535 
536         public String[] getTableNames() {
537             return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
538         }
539 
540         public void retainOnly(List tableNames) {
541 
542             List tablesToRemove = new ArrayList();
543             for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
544                 Map.Entry entry = (Map.Entry) iterator.next();
545                 String table = (String) entry.getKey();
546                 SortedSet pksToScan = (SortedSet) entry.getValue();
547                 boolean removeIt = pksToScan.isEmpty();
548 
549                 if ( ! tableNames.contains(table) ) {
550                     if ( this.logger.isWarnEnabled() ) {
551                         this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
552                         "as this table has not been passed as input" );
553                     }
554                     removeIt = true;
555                 }
556                 if ( removeIt ) {
557                     tablesToRemove.add( table );
558                 }
559             }
560 
561             for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
562                 this.remove( (String)iterator.next() );
563             }
564         }
565         
566         
567         public String toString() {
568             StringBuffer sb = new StringBuffer();
569             sb.append("pKsPerTable=").append(pksPerTable);
570             return sb.toString();
571         }
572 
573     }
574 }