View Javadoc
1   /*
2    *
3    * The DbUnit Database Testing Framework
4    * Copyright (C)2002-2005, DbUnit.org
5    *
6    * This library is free software; you can redistribute it and/or
7    * modify it under the terms of the GNU Lesser General Public
8    * License as published by the Free Software Foundation; either
9    * version 2.1 of the License, or (at your option) any later version.
10   *
11   * This library is distributed in the hope that it will be useful,
12   * but WITHOUT ANY WARRANTY; without even the implied warranty of
13   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14   * Lesser General Public License for more details.
15   *
16   * You should have received a copy of the GNU Lesser General Public
17   * License along with this library; if not, write to the Free Software
18   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19   *
20   */
21  package org.dbunit.database;
22  
23  import java.sql.PreparedStatement;
24  import java.sql.ResultSet;
25  import java.sql.SQLException;
26  import java.util.ArrayList;
27  import java.util.Collection;
28  import java.util.HashMap;
29  import java.util.HashSet;
30  import java.util.Iterator;
31  import java.util.List;
32  import java.util.Map;
33  import java.util.Set;
34  import java.util.SortedSet;
35  import java.util.TreeSet;
36  
37  import org.apache.commons.collections.map.ListOrderedMap;
38  import org.dbunit.database.search.ForeignKeyRelationshipEdge;
39  import org.dbunit.dataset.DataSetException;
40  import org.dbunit.dataset.IDataSet;
41  import org.dbunit.dataset.ITable;
42  import org.dbunit.dataset.ITableIterator;
43  import org.dbunit.dataset.ITableMetaData;
44  import org.dbunit.dataset.filter.AbstractTableFilter;
45  import org.dbunit.util.SQLHelper;
46  import org.slf4j.Logger;
47  import org.slf4j.LoggerFactory;
48  
49  /**
50   * Filter a table given a map of the allowed rows based on primary key values.<br>
51   * It uses a depth-first algorithm (although not recursive - it might be refactored
52   * in the future) to define which rows are allowed, as well which rows are necessary
53   * (and hence allowed) because of dependencies with the allowed rows.<br>
54   * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
55   * TODO: test cases
56   * @author Felipe Leme (dbunit@felipeal.net)
57   * @author Last changed by: $Author$
58   * @version $Revision$ $Date$
59   * @since Sep 9, 2005
60   */
61  public class PrimaryKeyFilter extends AbstractTableFilter {
62  
63      private final IDatabaseConnection connection;
64  
65      private final PkTableMap allowedPKsPerTable;
66      private final PkTableMap allowedPKsInput;
67      private final PkTableMap pksToScanPerTable;
68  
69      private final boolean reverseScan;
70  
71      protected final Logger logger = LoggerFactory.getLogger(getClass());
72  
73      // cache the primary keys
74      private final Map pkColumnPerTable = new HashMap();
75  
76      private final Map fkEdgesPerTable = new HashMap();
77      private final Map fkReverseEdgesPerTable = new HashMap();
78  
79      // name of the tables, in reverse order of dependency
80      private final List<String> tableNames = new ArrayList<String>();
81  
82      /**
83       * Default constructor, it takes as input a map with desired rows in a final
84       * dataset; the filter will ensure that the rows necessary by these initial rows
85       * are also allowed (and so on...).
86       * @param connection database connection
87       * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
88       * of a table; value is a Set with allowed primary keys for that table)
89       * @param reverseDependency flag indicating if the rows that depend on a row should
90       * also be allowed by the filter
91       */
92      public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
93          this.connection = connection;
94          this.allowedPKsPerTable = new PkTableMap();
95          this.allowedPKsInput = allowedPKs;
96          this.reverseScan = reverseDependency;
97  
98          // we need a deep copy here
99          this.pksToScanPerTable = new PkTableMap(allowedPKs);
100     }
101 
102     public void nodeAdded(Object node) {
103         // TODO:  nodeAdded should take a String, but because it is an inherited method,
104         //        we must keep the signature untouched.  One day, it must be fixed throughout
105         //        the class hierarchy.
106         this.tableNames.add( (String) node );
107         if ( this.logger.isDebugEnabled() ) {
108             this.logger.debug("nodeAdded: " + node );
109         }
110     }
111 
112     public void edgeAdded(ForeignKeyRelationshipEdge edge) {
113         if ( this.logger.isDebugEnabled() ) {
114             this.logger.debug("edgeAdded: " + edge );
115         }
116         // first add it to the "direct edges"
117         String from = (String) edge.getFrom();
118         Set edges = (Set) this.fkEdgesPerTable.get(from);
119         if ( edges == null ) {
120             edges = new HashSet();
121             this.fkEdgesPerTable.put( from, edges );
122         }
123         if ( ! edges.contains(edge) ) {
124             edges.add(edge);
125         }
126 
127         // then add it to the "reverse edges"
128         String to = (String) edge.getTo();
129         edges = (Set) this.fkReverseEdgesPerTable.get(to);
130         if ( edges == null ) {
131             edges = new HashSet();
132             this.fkReverseEdgesPerTable.put( to, edges );
133         }
134         if ( ! edges.contains(edge) ) {
135             edges.add(edge);
136         }
137 
138         // finally, update the PKs cache
139         updatePkCache(to, edge);
140 
141     }
142 
143     /**
144      * @see AbstractTableFilter
145      */
146     public boolean isValidName(String tableName) throws DataSetException {
147         //    boolean isValid = this.allowedIds.containsKey(tableName);
148         //    return isValid;
149         return true;
150     }
151 
152     public ITableIterator iterator(IDataSet dataSet, boolean reversed)
153     throws DataSetException {
154         if ( this.logger.isDebugEnabled() ) {
155             this.logger.debug("Filter.iterator()" );
156         }
157         try {
158             searchPKs(dataSet);
159         } catch (SQLException e) {
160             throw new DataSetException( e );
161         }
162         return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
163                 .iterator());
164     }
165 
166     private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
167         logger.debug("searchPKs(dataSet={}) - start", dataSet);
168 
169         int counter = 0;
170         while ( !this.pksToScanPerTable.isEmpty() ) {
171             counter ++;
172             if ( this.logger.isDebugEnabled() ) {
173                 this.logger.debug( "RUN # " + counter );
174             }
175 
176             for( int i=this.tableNames.size()-1; i>=0; i-- ) {
177                 String tableName = this.tableNames.get(i);
178                 // TODO: support multi-column PKs
179                 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
180                 Set tmpSet = this.pksToScanPerTable.get( tableName );
181                 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
182                     Set pksToScan = new HashSet( tmpSet );
183                     if ( this.logger.isDebugEnabled() ) {
184                         this.logger.debug(  "before search: "+ tableName + "=>" + pksToScan );
185                     }
186                     scanPKs( tableName, pkColumn, pksToScan );
187                     scanReversePKs( tableName, pksToScan );
188                     allowPKs( tableName, pksToScan );
189                     removePKsToScan( tableName, pksToScan );
190                 } // if
191             } // for
192             removeScannedTables();
193         } // while
194         if ( this.logger.isDebugEnabled() ) {
195             this.logger.debug( "Finished searchIds()" );
196         }
197     }
198 
199     private void removeScannedTables() {
200         logger.debug("removeScannedTables() - start");
201         this.pksToScanPerTable.retainOnly(this.tableNames);
202     }
203 
204     private void allowPKs(String table, Set newAllowedPKs) {
205         logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
206 
207         // then, add the new IDs, but checking if it should be allowed to add them
208         Set forcedAllowedPKs = this.allowedPKsInput.get( table );
209         if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
210             allowedPKsPerTable.addAll(table, newAllowedPKs );
211         } else {
212             for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
213                 Object id = iterator.next();
214                 if( forcedAllowedPKs.contains(id) ) {
215                     allowedPKsPerTable.add(table, id);
216                 }
217                 else
218                 {
219                     if ( this.logger.isDebugEnabled() ) {
220                         this.logger.debug( "Discarding id " + id + " of table " + table +
221                         " as it was not included in the input!" );
222                     }
223                 }
224             }
225         }
226     }
227 
228     private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
229         if (logger.isDebugEnabled())
230         {
231             logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
232                     new Object[]{ table, pkColumn, allowedIds });
233         }
234 
235         Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
236         if ( fkEdges == null || fkEdges.isEmpty() ) {
237             return;
238         }
239         // we need a temporary list as there is no warranty about the set order...
240         List fkTables = new ArrayList( fkEdges.size() );
241         StringBuffer colsBuffer = new StringBuffer();
242         for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
243             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
244             fkTables.add( edge.getTo() );
245             colsBuffer.append( edge.getFKColumn() );
246             if ( iterator.hasNext() ) {
247                 colsBuffer.append( ", " );
248             }
249         }
250         // NOTE: make sure the query below is compatible standard SQL
251         String sql = "SELECT " + colsBuffer + " FROM " + table +
252         " WHERE " + pkColumn + " = ? ";
253         if ( this.logger.isDebugEnabled() ) {
254             this.logger.debug( "SQL: " + sql );
255         }
256 
257         scanPKs(table, sql, allowedIds, fkTables);
258     }
259 
260     private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
261     {
262         PreparedStatement pstmt = null;
263         ResultSet rs = null;
264         try {
265             pstmt = this.connection.getConnection().prepareStatement( sql );
266             for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
267                 Object pk = iterator.next(); // id being scanned
268                 if( this.logger.isDebugEnabled() ) {
269                     this.logger.debug("Executing sql for ? = " + pk );
270                 }
271                 pstmt.setObject( 1, pk );
272                 rs = pstmt.executeQuery();
273                 while( rs.next() ) {
274                     for( int i=0; i<fkTables.size(); i++ ) {
275                         String newTable = (String) fkTables.get(i);
276                         Object fk = rs.getObject(i+1);
277                         if( fk != null ) {
278                             if( this.logger.isDebugEnabled() ) {
279                                 this.logger.debug("New ID: " + newTable + "->" + fk);
280                             }
281                             addPKToScan( newTable, fk );
282                         }
283                         else {
284                             this.logger.warn( "Found null FK for relationship  " +
285                                     table + "=>" + newTable );
286                         }
287                     }
288                 }
289             }
290         } catch (SQLException e) {
291             logger.error("scanPKs()", e);
292         }
293         finally {
294             // new in the finally block. has been in the catch only before
295             SQLHelper.close( rs, pstmt );
296         }
297     }
298 
299     private void scanReversePKs(String table, Set pksToScan) throws SQLException {
300         logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
301 
302         if ( ! this.reverseScan ) {
303             return;
304         }
305         Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
306         if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
307             return;
308         }
309         Iterator iterator = fkReverseEdges.iterator();
310         while ( iterator.hasNext() ) {
311             ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
312             addReverseEdge( edge, pksToScan );
313         }
314     }
315 
316     private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
317         logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
318 
319         String fkTable = (String) edge.getFrom();
320         String fkColumn = edge.getFKColumn();
321         String pkColumn = getPKColumn( fkTable );
322         // NOTE: make sure the query below is compatible standard SQL
323         String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
324 
325         PreparedStatement pstmt = null;
326         ResultSet rs = null;
327         try {
328             if ( this.logger.isDebugEnabled() ) {
329                 this.logger.debug( "Preparing SQL query '" + sql + "'" );
330             }
331             pstmt = this.connection.getConnection().prepareStatement( sql );
332             for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
333                 Object pk = iterator.next();
334                 if ( this.logger.isDebugEnabled() ) {
335                     this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
336                 }
337                 pstmt.setObject( 1, pk );
338                 rs = pstmt.executeQuery();
339                 while( rs.next() ) {
340                     Object fk = rs.getObject(1);
341                     addPKToScan( fkTable, fk );
342                 }
343             }
344         } finally {
345             SQLHelper.close( rs, pstmt );
346         }
347     }
348 
349     private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
350         logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
351 
352         Object pkTo = this.pkColumnPerTable.get(table);
353         if ( pkTo == null ) {
354             String pkColumn = edge.getPKColumn();
355             this.pkColumnPerTable.put( table, pkColumn );
356         }
357     }
358 
359     // TODO: support PKs with multiple values
360     private String getPKColumn( String table ) throws SQLException {
361         logger.debug("getPKColumn(table={}) - start", table);
362 
363         // Try to get the cached column
364         String pkColumn = (String) this.pkColumnPerTable.get( table );
365         if ( pkColumn == null ) {
366             // If the column has not been cached until now retrieve it from the database connection
367             pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
368             this.pkColumnPerTable.put( table, pkColumn );
369         }
370         return pkColumn;
371     }
372 
373 
374     private void removePKsToScan(String table, Set ids) {
375         logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
376 
377         Set pksToScan = this.pksToScanPerTable.get(table);
378         if ( pksToScan != null ) {
379             if ( pksToScan == ids ) {
380                 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
381             } else {
382                 pksToScan.removeAll( ids );
383             }
384         }
385     }
386 
387     private void addPKToScan(String table, Object pk) {
388         logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
389 
390         // first, check if it wasn't added yet
391         if(this.allowedPKsPerTable.contains(table, pk)) {
392             if ( this.logger.isDebugEnabled() ) {
393                 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
394             }
395             return;
396         }
397 
398         this.pksToScanPerTable.add(table, pk);
399     }
400 
401     public String toString() {
402         StringBuffer sb = new StringBuffer();
403         sb.append("tableNames=").append(tableNames);
404         sb.append(", allowedPKsInput=").append(allowedPKsInput);
405         sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
406         sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
407         sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
408         sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
409         sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
410         sb.append(", reverseScan=").append(reverseScan);
411         sb.append(", connection=").append(connection);
412         return sb.toString();
413     }
414 
415 
416     private class FilterIterator implements ITableIterator {
417 
418         private final ITableIterator _iterator;
419 
420         public FilterIterator(ITableIterator iterator) {
421 
422             _iterator = iterator;
423         }
424 
425         ////////////////////////////////////////////////////////////////////////////
426         // ITableIterator interface
427 
428         public boolean next() throws DataSetException {
429             if ( logger.isDebugEnabled() ) {
430                 logger.debug("Iterator.next()" );
431             }
432             while (_iterator.next()) {
433                 if (accept(_iterator.getTableMetaData().getTableName())) {
434                     return true;
435                 }
436             }
437             return false;
438         }
439 
440         public ITableMetaData getTableMetaData() throws DataSetException {
441             if ( logger.isDebugEnabled() ) {
442                 logger.debug("Iterator.getTableMetaData()" );
443             }
444             return _iterator.getTableMetaData();
445         }
446 
447         public ITable getTable() throws DataSetException {
448             if ( logger.isDebugEnabled() ) {
449                 logger.debug("Iterator.getTable()" );
450             }
451             ITable table = _iterator.getTable();
452             String tableName = table.getTableMetaData().getTableName();
453             Set allowedPKs = allowedPKsPerTable.get( tableName );
454             if ( allowedPKs != null ) {
455                 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
456             }
457             return table;
458         }
459     }
460 
461     /**
462      * Map that associates a table with a set of primary key objects.
463      *
464      * @author gommma (gommma AT users.sourceforge.net)
465      * @author Last changed by: $Author$
466      * @version $Revision$ $Date$
467      * @since 2.3.0
468      */
469     public static class PkTableMap
470     {
471         private final ListOrderedMap pksPerTable;
472         private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
473 
474         public PkTableMap()
475         {
476             this.pksPerTable = new ListOrderedMap();
477         }
478 
479         /**
480          * Copy constructor
481          * @param allowedPKs
482          */
483         public PkTableMap(PkTableMap allowedPKs) {
484             this.pksPerTable = new ListOrderedMap();
485             Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
486             while ( iterator.hasNext() ) {
487                 Map.Entry entry = (Map.Entry) iterator.next();
488                 String table = (String)entry.getKey();
489                 SortedSet pkObjectSet = (SortedSet) entry.getValue();
490                 SortedSet newSet = new TreeSet( pkObjectSet );
491                 this.pksPerTable.put( table, newSet );
492             }
493         }
494 
495         public int size() {
496             return pksPerTable.size();
497         }
498 
499         public boolean isEmpty() {
500             return pksPerTable.isEmpty();
501         }
502 
503         public boolean contains(String table, Object pkObject) {
504             Set pksPerTable = this.get(table);
505             return (pksPerTable != null && pksPerTable.contains(pkObject));
506         }
507 
508         public void remove(String tableName) {
509             this.pksPerTable.remove(tableName);
510         }
511 
512         public void put(String table, SortedSet pkObjects) {
513             this.pksPerTable.put(table, pkObjects);
514         }
515 
516         public void add(String tableName, Object pkObject) {
517             Set pksPerTable = getCreateIfNeeded(tableName);
518             pksPerTable.add(pkObject);
519         }
520 
521         public void addAll(String tableName, Set pkObjectsToAdd) {
522             Set pksPerTable = this.getCreateIfNeeded(tableName);
523             pksPerTable.addAll(pkObjectsToAdd);
524         }
525 
526         public SortedSet get(String tableName) {
527             return (SortedSet) this.pksPerTable.get(tableName);
528         }
529 
530         private SortedSet getCreateIfNeeded(String tableName){
531             SortedSet pksPerTable = this.get(tableName);
532             // Lazily create the set if it did not exist yet
533             if( pksPerTable == null ) {
534                 pksPerTable = new TreeSet();
535                 this.pksPerTable.put(tableName, pksPerTable);
536             }
537             return pksPerTable;
538         }
539 
540         public String[] getTableNames() {
541             return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
542         }
543 
544         public void retainOnly(List<String> tableNames) {
545 
546             List tablesToRemove = new ArrayList();
547             for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
548                 Map.Entry entry = (Map.Entry) iterator.next();
549                 String table = (String) entry.getKey();
550                 SortedSet pksToScan = (SortedSet) entry.getValue();
551                 boolean removeIt = pksToScan.isEmpty();
552 
553                 if ( ! tableNames.contains(table) ) {
554                     if ( this.logger.isWarnEnabled() ) {
555                         this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
556                         "as this table has not been passed as input" );
557                     }
558                     removeIt = true;
559                 }
560                 if ( removeIt ) {
561                     tablesToRemove.add( table );
562                 }
563             }
564 
565             for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
566                 this.remove( (String)iterator.next() );
567             }
568         }
569 
570         public String toString() {
571             StringBuffer sb = new StringBuffer();
572             sb.append("pKsPerTable=").append(pksPerTable);
573             return sb.toString();
574         }
575 
576     }
577 
578     public void addTable(ITable table) throws AmbiguousTableNameException {
579         logger.debug("addTable() - start");
580         nodeAdded(table.getTableMetaData().getTableName());
581     }
582 
583     public void addTables(Collection<ITable> tables) throws AmbiguousTableNameException {
584         logger.debug("addTables(Collection) - start");
585         for(ITable table: tables) {
586             addTable(table);
587         }
588     }
589 
590     public void addTables(IDataSet dataSet) throws DataSetException {
591         logger.debug("addTables(IDataSet) - start");
592         ITableIterator iterator = dataSet.iterator();
593         while(iterator.next())
594             addTable(iterator.getTable());
595     }
596 }